From 182f90b5eca2baa25474223759039925b286d562 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 00:11:53 -0700 Subject: [PATCH 1/5] Lower cond vram use by casting at the same time as device transfer. (#9159) --- comfy/conds.py | 14 +++++++------- comfy/model_base.py | 6 +++--- comfy/samplers.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/comfy/conds.py b/comfy/conds.py index 2af2a43a3..f2564e7ef 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -10,8 +10,8 @@ class CONDRegular: def _copy_with(self, cond): return self.__class__(cond) - def process_cond(self, batch_size, device, **kwargs): - return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) + def process_cond(self, batch_size, **kwargs): + return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size)) def can_concat(self, other): if self.cond.shape != other.cond.shape: @@ -29,14 +29,14 @@ class CONDRegular: class CONDNoiseShape(CONDRegular): - def process_cond(self, batch_size, device, area, **kwargs): + def process_cond(self, batch_size, area, **kwargs): data = self.cond if area is not None: dims = len(area) // 2 for i in range(dims): data = data.narrow(i + 2, area[i + dims], area[i]) - return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device)) + return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size)) class CONDCrossAttn(CONDRegular): @@ -73,7 +73,7 @@ class CONDConstant(CONDRegular): def __init__(self, cond): self.cond = cond - def process_cond(self, batch_size, device, **kwargs): + def process_cond(self, batch_size, **kwargs): return self._copy_with(self.cond) def can_concat(self, other): @@ -92,10 +92,10 @@ class CONDList(CONDRegular): def __init__(self, cond): self.cond = cond - def process_cond(self, batch_size, device, **kwargs): + def process_cond(self, batch_size, **kwargs): out = [] for c in self.cond: - out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device)) + out.append(comfy.utils.repeat_to_batch_size(c, batch_size)) return self._copy_with(out) diff --git a/comfy/model_base.py b/comfy/model_base.py index 4556ee138..3a9c031ea 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -109,9 +109,9 @@ def model_sampling(model_config, model_type): def convert_tensor(extra, dtype, device): if hasattr(extra, "dtype"): if extra.dtype != torch.int and extra.dtype != torch.long: - extra = extra.to(dtype=dtype, device=device) + extra = comfy.model_management.cast_to_device(extra, device, dtype) else: - extra = extra.to(device=device) + extra = comfy.model_management.cast_to_device(extra, device, None) return extra @@ -174,7 +174,7 @@ class BaseModel(torch.nn.Module): device = xc.device t = self.model_sampling.timestep(t).float() if context is not None: - context = context.to(dtype=dtype, device=device) + context = comfy.model_management.cast_to_device(context, device, dtype) extra_conds = {} for o in kwargs: diff --git a/comfy/samplers.py b/comfy/samplers.py index e93d2a315..ad2f40cdc 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in): conditioning = {} model_conds = conds["model_conds"] for c in model_conds: - conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) + conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area) hooks = conds.get('hooks', None) control = conds.get('control', None) From 140ffc7fdc53e810030f060e421c1f528c2d2ab9 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 00:28:12 -0700 Subject: [PATCH 2/5] Fix broken controlnet from last PR. (#9167) --- comfy/controlnet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 6ed8bd756..988acdb57 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -28,6 +28,7 @@ import comfy.model_detection import comfy.model_patcher import comfy.ops import comfy.latent_formats +import comfy.model_base import comfy.cldm.cldm import comfy.t2i_adapter.adapter @@ -264,12 +265,12 @@ class ControlNet(ControlBase): for c in self.extra_conds: temp = cond.get(c, None) if temp is not None: - extra[c] = temp.to(dtype) + extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device) timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra) return self.control_merge(control, control_prev, output_dtype=None) def copy(self): From 7991341e89cab521441641505ac4b0eea292a829 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 01:02:40 -0700 Subject: [PATCH 3/5] Various fixes for broken things from earlier PR. (#9168) --- comfy/model_base.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 3a9c031ea..f9591f292 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -401,7 +401,7 @@ class SD21UNCLIP(BaseModel): unclip_conditioning = kwargs.get("unclip_conditioning", None) device = kwargs["device"] if unclip_conditioning is None: - return torch.zeros((1, self.adm_channels)) + return torch.zeros((1, self.adm_channels), device=device) else: return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10) @@ -409,7 +409,7 @@ def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: - return args["pooled_output"] + return args["pooled_output"].to(device=args["device"]) class SDXLRefiner(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): @@ -615,9 +615,11 @@ class IP2P: if image is None: image = torch.zeros_like(noise) + else: + image = image.to(device=device) if image.shape[1:] != noise.shape[1:]: - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center") image = utils.resize_to_batch_size(image, noise.shape[0]) return self.process_ip2p_image_in(image) @@ -696,7 +698,7 @@ class StableCascade_B(BaseModel): #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)) - out["effnet"] = comfy.conds.CONDRegular(prior) + out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device)) out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,))) return out @@ -1161,10 +1163,10 @@ class WAN21_Vace(WAN21): vace_frames_out = [] for j in range(len(vace_frames)): - vf = vace_frames[j].clone() + vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True) for i in range(0, vf.shape[1], 16): vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16]) - vf = torch.cat([vf, mask[j]], dim=1) + vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1) vace_frames_out.append(vf) vace_frames = torch.stack(vace_frames_out, dim=1) From 84f9759424ccbd8de710960c79f0f1d28eef2776 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 01:20:12 -0700 Subject: [PATCH 4/5] Add some warnings and prevent crash when cond devices don't match. (#9169) --- comfy/conds.py | 7 +++++++ comfy/model_base.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/comfy/conds.py b/comfy/conds.py index f2564e7ef..5af3e93ea 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -1,6 +1,7 @@ import torch import math import comfy.utils +import logging class CONDRegular: @@ -16,6 +17,9 @@ class CONDRegular: def can_concat(self, other): if self.cond.shape != other.cond.shape: return False + if self.cond.device != other.cond.device: + logging.warning("WARNING: conds not on same device, skipping concat.") + return False return True def concat(self, others): @@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular): diff = mult_min // min(s1[1], s2[1]) if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much return False + if self.cond.device != other.cond.device: + logging.warning("WARNING: conds not on same device: skipping concat.") + return False return True def concat(self, others): diff --git a/comfy/model_base.py b/comfy/model_base.py index f9591f292..2db81e244 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -409,7 +409,7 @@ def sdxl_pooled(args, noise_augmentor): if "unclip_conditioning" in args: return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:,:1280] else: - return args["pooled_output"].to(device=args["device"]) + return args["pooled_output"] class SDXLRefiner(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): From 03895dea7c4a6cc93fa362cd11ca450217d74b13 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 4 Aug 2025 01:33:04 -0700 Subject: [PATCH 5/5] Fix another issue with the PR. (#9170) --- comfy/model_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 2db81e244..a06686436 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -162,7 +162,7 @@ class BaseModel(torch.nn.Module): xc = self.model_sampling.calculate_input(sigma, x) if c_concat is not None: - xc = torch.cat([xc] + [c_concat], dim=1) + xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1) context = c_crossattn dtype = self.get_dtype()