From 822a3ecf7372760eafb3b367fff805d8b0f2fbc5 Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Thu, 21 May 2026 11:47:53 -0700 Subject: [PATCH] Note _calc_cond_batch and _calc_cond_batch_multigpu must stay in sync Per review feedback on #7063. The two functions share the conds-by-hooks accumulation, memory-fit batching, and per-chunk output aggregation; the multigpu variant adds per-device scheduling, .to(device) placement, per-device patcher/control lookup, and thread-pool dispatch around the inner loop. Documenting the relationship without extracting helpers -- extraction can land after the initial worksplit-multigpu release once both paths have settled. Amp-Thread-ID: https://ampcode.com/threads/T-019e4a00-fe3d-76bd-a2f2-a8c8c4040082 Co-authored-by: Amp --- comfy/samplers.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy/samplers.py b/comfy/samplers.py index 8bfc42bdb..6fd0387d5 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -218,6 +218,9 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc return executor.execute(model, conds, x_in, timestep, model_options) def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + # NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic + # (hooked_to_run accumulation, memory-fit batching, per-chunk output + # aggregation) is duplicated there with per-device scheduling layered on top. if 'multigpu_clones' in model_options: return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options) out_conds = [] @@ -353,6 +356,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens return out_conds def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + # NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks + # accumulation, memory-fit batching, and output aggregation, but adds a + # per-device scheduler, per-device patcher/control lookup, tensor .to(device) + # placement, and MultiGPUThreadPool dispatch around the inner loop. out_conds = [] out_counts = [] # separate conds by matching hooks