When a batch of inputs is available and several GPUs are idle, the work can be split across them. ToolPool does this automatically. A call made inside a with ToolPool(): block is partitioned across every available device, the pieces run in parallel worker processes, and the results are reassembled in the original order.

ToolPool is the multi-GPU counterpart to ToolInstance.persist(). persist() is appropriate when one worker on one GPU is enough to amortize the model-loading cost across a batch. ToolPool() is appropriate when that same batch should be spread across several GPUs to reduce wall-clock time. The two mechanisms cooperate: inside a ToolPool block, every worker stays warm for later calls in the same block, so persistence comes automatically.

| Feature | What it does |
| --- | --- |
| Transparent interception | Any tool with an iterable_input_field is auto-partitioned. |
| Cost-aware scheduling | Items are distributed via LPT bin-packing using each tool's item_cost() estimate. |
| Built-in persistence | Workers stay alive across calls within the pool, so each model loads at most once per GPU. |
| Device auto-detection | Discovers all visible GPUs, or accepts an explicit device list. |
| Automatic dedup | Duplicate items are computed once and expanded back to their original positions. |
ToolPool is for local inference on hardware under the caller's control. When inference runs in the cloud, partitioning and parallelism are handled remotely, so ToolPool is not needed.

1. Basic usage (auto-detect GPUs)

The simplest form of ToolPool takes no arguments. Every GPU the process can see joins the pool, and any run_* call inside the block is partitioned across all of them. Results are returned in input order, so the multi-GPU execution is invisible to the caller.
```python
import random

from proto_tools.tools.structure_prediction.esmfold import (
    run_esmfold, ESMFoldInput,
)
from proto_tools.utils.tool_pool import ToolPool

# A batch of 48 distinct short peptides to fold across the available GPUs
rng = random.Random(0)
sequences = ["".join(rng.choices("ACDEFGHIKLMNPQRSTVWY", k=15)) for _ in range(48)]
complexes = [[seq] for seq in sequences]  # 48 single-chain complexes

with ToolPool():
    output = run_esmfold(ESMFoldInput(complexes=complexes))

# Structures come back in the same order as the input
assert len(output.structures) == len(complexes)
```

[Figure: ToolPool fans a 48-sequence batch across 4 GPUs, 12 sequences per GPU]

On a machine with four GPUs, the 48-sequence batch above is distributed as 12 sequences per GPU. Each worker loads the model once (all workers load in parallel), runs its slice, and the pool reassembles the structures into the original order before returning.
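The split-and-reassemble step can be sketched in plain Python. This is a minimal illustration under stated assumptions, not ToolPool's implementation: `fan_out` and `fake_fold` are hypothetical names, threads stand in for per-GPU worker processes, and items are dealt out in contiguous slices for simplicity (ToolPool itself uses cost-aware scheduling). Each item carries its original index, so results land back in input order regardless of which worker finishes first.

```python
from concurrent.futures import ThreadPoolExecutor


def fan_out(items, devices, run_slice):
    """Split items across devices, run the slices concurrently, restore input order."""
    if not devices:
        raise ValueError("need at least one device")
    n = len(devices)
    # Tag every item with its original position before partitioning.
    indexed = list(enumerate(items))
    # Contiguous, near-equal slices: the first `extra` slices get one more item.
    size, extra = divmod(len(indexed), n)
    slices, start = [], 0
    for i in range(n):
        end = start + size + (1 if i < extra else 0)
        slices.append(indexed[start:end])
        start = end

    results = [None] * len(items)
    # Threads stand in for the per-GPU worker processes (hypothetical sketch).
    with ThreadPoolExecutor(max_workers=n) as pool:
        futures = [pool.submit(run_slice, device, part)
                   for device, part in zip(devices, slices)]
        for future in futures:
            # Each worker returns (original_index, result) pairs, so results
            # are written to their input positions, not in completion order.
            for idx, value in future.result():
                results[idx] = value
    return results


def fake_fold(device, part):
    # Hypothetical stand-in for a model call: report device and sequence length.
    return [(idx, (device, len(seq))) for idx, seq in part]


sequences = [f"SEQ{i}" for i in range(48)]
devices = ["cuda:0", "cuda:1", "cuda:2", "cuda:3"]
out = fan_out(sequences, devices, fake_fold)
```

With four devices and 48 items, each device receives 12 items, and `out[i]` always corresponds to `sequences[i]`.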

2. Restricting or choosing devices

In some cases not every visible GPU should be used: another job may occupy some of the cards, or a benchmark may require a fixed number of devices. Passing gpus=[...] restricts the pool explicitly.
```python
with ToolPool(gpus=["cuda:0", "cuda:1"]):
    output = run_esmfold(ESMFoldInput(complexes=complexes))
```

[Figure: Passing gpus=[...] restricts the ToolPool; unlisted GPUs stay free for other workloads]

Only the listed GPUs join the pool, so on the four-GPU machine above the same batch is partitioned across two workers instead of four. The unlisted GPUs remain available for other work on the same machine.
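Device selection comes down to a small resolution rule: an explicit list wins, otherwise every visible device joins. The sketch below is hypothetical (`resolve_devices` is not part of proto-tools) and takes the visible-device count as a parameter so it runs without a GPU. In a PyTorch environment that count would typically come from `torch.cuda.device_count()`, which honors `CUDA_VISIBLE_DEVICES`.

```python
def resolve_devices(gpus=None, visible_count=0):
    """Hypothetical device-resolution rule for a GPU pool."""
    if gpus is None:
        # Auto-detect: every visible GPU joins the pool.
        devices = [f"cuda:{i}" for i in range(visible_count)]
    else:
        devices = list(gpus)
        for dev in devices:
            # Accept only "cuda:<index>" strings that name a visible GPU.
            prefix, _, index = dev.partition(":")
            if prefix != "cuda" or not index.isdigit() or int(index) >= visible_count:
                raise ValueError(f"unknown or invisible device: {dev!r}")
        if len(set(devices)) != len(devices):
            raise ValueError("duplicate device in gpus list")
    if not devices:
        raise RuntimeError("no GPUs available for the pool")
    return devices


auto = resolve_devices(visible_count=4)
restricted = resolve_devices(gpus=["cuda:0", "cuda:2"], visible_count=4)
```

Validating the explicit list up front means a typo such as `"cuda:7"` on a four-GPU machine fails immediately rather than after some workers have already loaded.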

3. Persistence within a pool

Once a worker loads inside a ToolPool block, it stays resident for every later call in that same block. The first call pays the model-loading cost on every GPU in the pool; every later call skips the load and runs against workers that are already warm.
```python
with ToolPool(gpus=["cuda:0", "cuda:1"]):
    # Cold call: pays model loading on both GPUs
    result1 = run_esmfold(ESMFoldInput(complexes=complexes))

    # other tasks here ...

    # Warm call: workers already loaded
    result2 = run_esmfold(ESMFoldInput(complexes=complexes))

    # other tasks here ...

    # Also warm: no reload
    result3 = run_esmfold(ESMFoldInput(complexes=complexes))
```

[Figure: Workers stay resident across calls within the same ToolPool block; only the first call pays the load cost]

On exit from the block, all workers are shut down and their GPU memory is released. The pool manages this lifecycle automatically.
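The lifecycle described above (load on the first call, reuse on later calls, release everything on exit) maps naturally onto a Python context manager. The sketch below is hypothetical: `WarmPool` and `load_model` are illustrative names, a dictionary stands in for resident GPU workers, calls run sequentially for brevity, and a counter records how many loads actually happen.

```python
class WarmPool:
    """Hypothetical sketch of per-device worker persistence inside a block."""

    def __init__(self, devices, load_model):
        self.devices = list(devices)
        self.load_model = load_model
        self.workers = {}  # (tool_name, device) -> loaded model
        self.loads = 0

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Shut down every worker; a real pool would also free GPU memory here.
        self.workers.clear()
        return False  # do not swallow exceptions

    def get_worker(self, tool_name, device):
        key = (tool_name, device)
        if key not in self.workers:  # cold: load once per device
            self.workers[key] = self.load_model(tool_name, device)
            self.loads += 1
        return self.workers[key]  # warm: reuse the resident worker

    def run(self, tool_name, items):
        # Sequential round-robin for brevity; a real pool runs devices in parallel.
        out = []
        for i, item in enumerate(items):
            device = self.devices[i % len(self.devices)]
            out.append(self.get_worker(tool_name, device)(item))
        return out


def load_model(tool_name, device):
    # Stand-in for an expensive model load.
    return lambda item: f"{tool_name}@{device}:{item}"


with WarmPool(["cuda:0", "cuda:1"], load_model) as pool:
    for _ in range(3):  # one cold call, then two warm calls
        results = pool.run("esmfold", ["A", "B", "C", "D"])
    loads_inside = pool.loads
    workers_inside = len(pool.workers)
workers_after = len(pool.workers)
```

Three calls on two devices trigger only two loads, and leaving the block empties the worker table.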

4. Cost-aware scheduling

Real batches rarely contain items of uniform cost. One protein sequence may be 40 residues long and another 800; the longer sequence needs far more compute, so it dominates whichever partition receives it. Distributing items round-robin makes the worker that happened to receive the long items the bottleneck while the others sit idle.

ToolPool avoids this with longest-processing-time-first (LPT) bin-packing. Each tool reports a per-item cost estimate through item_cost() (for structure prediction, for example, the total residue count). ToolPool sorts the batch by descending cost and assigns each item to whichever worker currently has the least total work. As a result, the GPUs finish at approximately the same wall-clock time, unless a single item is so large that it alone exceeds a fair share of the total work.

[Figure: Naive round-robin leaves GPUs idle while one bottlenecks; LPT balances finish times across all GPUs]

This requires no special action. As long as the tool's item_cost() is reasonable, a mixed batch of short and long inputs is balanced automatically. For a batch of equally sized inputs, the schedule reduces to round-robin.
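LPT itself fits in a few lines with a min-heap keyed on each worker's accumulated load. The sketch below illustrates the textbook algorithm, not ToolPool's code: `lpt_partition`, `round_robin_partition`, and `makespan` are hypothetical helpers, and the cost values are made-up residue counts chosen only to show the difference in makespan (the finish time of the busiest worker).

```python
import heapq


def lpt_partition(costs, n_workers):
    """Longest-processing-time-first: return one list of item indices per worker."""
    # Heap of (current_load, worker_id); the least-loaded worker is on top.
    heap = [(0, w) for w in range(n_workers)]
    bins = [[] for _ in range(n_workers)]
    # Visit items from most to least expensive (stable for equal costs).
    for idx in sorted(range(len(costs)), key=lambda i: costs[i], reverse=True):
        load, w = heapq.heappop(heap)
        bins[w].append(idx)
        heapq.heappush(heap, (load + costs[idx], w))
    return bins


def round_robin_partition(costs, n_workers):
    # Naive baseline: deal items out in input order, ignoring cost.
    return [list(range(w, len(costs), n_workers)) for w in range(n_workers)]


def makespan(bins, costs):
    # Wall-clock time is set by the worker with the most total work.
    return max(sum(costs[i] for i in b) for b in bins)


# Hypothetical per-item costs, e.g. residue counts for a mixed batch.
costs = [800, 40, 40, 40, 800, 40, 40, 40, 400, 400, 40, 40]
rr = makespan(round_robin_partition(costs, 4), costs)
lpt = makespan(lpt_partition(costs, 4), costs)
```

Here round-robin stacks both 800-residue items and one 400-residue item on the first worker (makespan 2000), while LPT puts each long item on its own worker (makespan 800, which is also the lower bound set by the largest single item). With equal costs, the stable sort and the worker-id tie-break make LPT deal items out exactly like round-robin.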

5. Automatic deduplication

When the same input appears multiple times in a batch, ToolPool computes it once and expands the result back to every position where it appeared. This is transparent, so the returned list always has the same length and order as the input.
```python
# 1000 inputs in a single call, but only 3 unique sequences
sequences = (["MKTLLILAVVAAALA"] * 400
             + ["GAVLTVLLGGLLLA"] * 300
             + ["MGQQPGKVLGDQRR"] * 300)

with ToolPool():
    output = run_esmfold(ESMFoldInput(complexes=[[s] for s in sequences]))

# Only 3 folds were actually run, but one structure is returned per input
assert len(output.structures) == len(sequences)  # 1000
```
This matters most for sweeps and sampling schemes that produce many duplicates of a small set of distinct inputs. Without deduplication, the same structure would be folded hundreds of times unnecessarily.
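Deduplication amounts to a map from each distinct input to its first occurrence plus an inverse index back to every original position. The sketch below is hypothetical (`dedup` and `expand` are illustrative helpers, and `len` stands in for the expensive model call). It requires hashable items, which holds for sequence strings; a list-of-chains complex would first need converting to a tuple.

```python
def dedup(items):
    """Return (unique_items, inverse) with unique_items[inverse[i]] == items[i]."""
    first_seen = {}  # item -> position in unique_items
    unique_items, inverse = [], []
    for item in items:
        if item not in first_seen:
            first_seen[item] = len(unique_items)
            unique_items.append(item)
        inverse.append(first_seen[item])
    return unique_items, inverse


def expand(unique_results, inverse):
    # Copy each computed result back to every position it came from.
    return [unique_results[j] for j in inverse]


sequences = (["MKTLLILAVVAAALA"] * 400
             + ["GAVLTVLLGGLLLA"] * 300
             + ["MGQQPGKVLGDQRR"] * 300)
unique, inverse = dedup(sequences)
unique_results = [len(s) for s in unique]  # stand-in for 3 real folds
results = expand(unique_results, inverse)
```

Only the three unique items reach the expensive step, yet `results` has one entry per input, in input order.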

6. When to use ToolPool versus ToolInstance.persist()

ToolPool and ToolInstance.persist() address related but distinct problems:

| Situation | Use |
| --- | --- |
| One GPU, batch of calls | ToolInstance.persist() |
| Multiple GPUs, single large batch | ToolPool() |
| Multiple GPUs, manual control over which tool runs where | ToolInstance.persist_tool(instance_name=...) |
| Cloud inference | Neither; the provider handles it |
One caveat: ToolPool only accelerates tools that declare an iterable_input_field (for example, complexes for ESMFold, sequences for ESM2). A tool that takes a single indivisible input runs on one GPU regardless of the pool size, because there is nothing to partition.
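Declaring an iterable_input_field is what lets a generic scheduler know which attribute to slice; every other field is copied unchanged into each partition. The sketch below uses a dataclass to show the idea. `FoldInput`, its `num_recycles` field, and `split_input` are hypothetical, and proto-tools input models may be defined differently.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class FoldInput:
    # Hypothetical input model: "complexes" plays the role of the
    # iterable_input_field; "num_recycles" is a shared, non-iterable setting.
    complexes: tuple
    num_recycles: int = 4


def split_input(inp, field_name, index_groups):
    """Build one input per worker, each holding a subset of `field_name`.

    index_groups: one list of item indices per worker (e.g. from a scheduler).
    """
    if field_name is None:
        # Indivisible input: the whole request goes to a single worker.
        return [inp]
    items = getattr(inp, field_name)
    # Workers with no assigned items receive no partition.
    return [replace(inp, **{field_name: tuple(items[i] for i in group)})
            for group in index_groups if group]


batch = FoldInput(complexes=(("MKTLL",), ("GAVLT",), ("MGQQP",)), num_recycles=2)
parts = split_input(batch, "complexes", [[0, 2], [1], []])
whole = split_input(batch, None, [[0], [1]])
```

The shared setting travels with every partition, and an input without an iterable field comes back as a single unsplit request.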

Configuration reference

```python
from proto_tools.utils.tool_pool import ToolPool

# Auto-detect every visible GPU
with ToolPool():
    ...

# Explicit device list
with ToolPool(gpus=["cuda:0", "cuda:2"]):
    ...
```

Go deeper

For the implementation details behind this guide, consult the developer notes in the proto-tools repository: ToolPool & Parallel Execution, which covers pool construction, what gets partitioned, LPT scheduling and item_cost, deduplication, and partition failure handling.