diff --git a/comfy/samplers.py b/comfy/samplers.py index f0d67cb7e..a99af5217 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -433,7 +433,11 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) * 1.5 < free_memory: + cond_shapes = collections.defaultdict(list) + for tt in batch_amount: + for k, v in to_run[tt][0].conditioning.items(): + cond_shapes[k].append(v.size()) + if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory: to_batch = batch_amount break