diff --git a/comfy/sample.py b/comfy/sample.py index 79ea37e0d..f5b68954d 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -38,8 +38,10 @@ def broadcast_cond(cond, batch, device): copy = [] for p in cond: t = p[0] - if t.shape[0] < batch: - t = torch.cat([t] * batch) + if t.shape[0] != batch: + t_list = [t for _ in range(batch // t.shape[0])] + t_list.append(t[:batch % t.shape[0]]) + t = torch.cat(t_list) t = t.to(device) copy += [[t] + p[1:]] return copy