diff --git a/comfy/conds.py b/comfy/conds.py
index f2564e7ef..5af3e93ea 100644
--- a/comfy/conds.py
+++ b/comfy/conds.py
@@ -1,6 +1,7 @@
 import torch
 import math
 import comfy.utils
+import logging


 class CONDRegular:
@@ -16,6 +17,9 @@ class CONDRegular:
     def can_concat(self, other):
         if self.cond.shape != other.cond.shape:
             return False
+        if self.cond.device != other.cond.device:
+            logging.warning("WARNING: conds not on same device, skipping concat.")
+            return False
         return True

     def concat(self, others):
@@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
             diff = mult_min // min(s1[1], s2[1])
             if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
                 return False
+        if self.cond.device != other.cond.device:
+            logging.warning("WARNING: conds not on same device, skipping concat.")
+            return False
         return True

     def concat(self, others):
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 3a9c031ea..a06686436 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -162,7 +162,7 @@ class BaseModel(torch.nn.Module):
         xc = self.model_sampling.calculate_input(sigma, x)
         if c_concat is not None:
-            xc = torch.cat([xc] + [c_concat], dim=1)
+            xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)

         context = c_crossattn
         dtype = self.get_dtype()
@@ -401,7 +401,7 @@ class SD21UNCLIP(BaseModel):
         unclip_conditioning = kwargs.get("unclip_conditioning", None)
         device = kwargs["device"]
         if unclip_conditioning is None:
-            return torch.zeros((1, self.adm_channels))
+            return torch.zeros((1, self.adm_channels), device=device)
         else:
             return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
@@ -615,9 +615,11 @@ class IP2P:

         if image is None:
             image = torch.zeros_like(noise)
+        else:
+            image = image.to(device=device)

         if image.shape[1:] != noise.shape[1:]:
-            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

         image = utils.resize_to_batch_size(image, noise.shape[0])
         return self.process_ip2p_image_in(image)
@@ -696,7 +698,7 @@ class StableCascade_B(BaseModel):

         #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
         prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))

-        out["effnet"] = comfy.conds.CONDRegular(prior)
+        out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
         out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
         return out
@@ -1161,10 +1163,10 @@ class WAN21_Vace(WAN21):

         vace_frames_out = []
         for j in range(len(vace_frames)):
-            vf = vace_frames[j].clone()
+            vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
             for i in range(0, vf.shape[1], 16):
                 vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
-            vf = torch.cat([vf, mask[j]], dim=1)
+            vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
             vace_frames_out.append(vf)
         vace_frames = torch.stack(vace_frames_out, dim=1)
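
The recurring pattern in this patch is a device guard: two conditioning tensors can only be batched with `torch.cat` if they live on the same device, so `can_concat` now refuses the merge instead of letting the concat raise at sampling time. Below is a minimal, self-contained sketch of that guard; the plain tensors and their shape are illustrative stand-ins for `CONDRegular.cond`, and the CUDA branch is only taken when a GPU is actually available so the snippet stays runnable on CPU-only hosts.

```python
import logging
import torch

def can_concat(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Mirrors the guard added to CONDRegular.can_concat: torch.cat raises a
    # RuntimeError on tensors from different devices, so refuse the batch up
    # front and let the sampler fall back to separate forward passes.
    if a.shape != b.shape:
        return False
    if a.device != b.device:
        logging.warning("WARNING: conds not on same device, skipping concat.")
        return False
    return True

x = torch.zeros((1, 77, 768))
y = x.to("cuda") if torch.cuda.is_available() else x.clone()
print(can_concat(x, y))  # False when the devices differ, True otherwise
```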
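The `WAN21_Vace` hunk swaps `.clone()` for `Tensor.to(device=..., dtype=..., copy=True)`. With `copy=True`, `.to()` always returns a fresh tensor even when no conversion is needed, so the in-place latent processing still cannot alias the caller's frames, and the device/dtype normalization happens in the same call instead of cloning first and moving afterwards. A small sketch of the idiom, with hypothetical shapes and a toy stand-in for `process_latent_in`:

```python
import torch

noise = torch.zeros((1, 16, 4, 8), dtype=torch.float32)   # reference tensor
frame = torch.randn((1, 32, 4, 8), dtype=torch.float64)   # hypothetical VACE frame

# One call moves, casts, and copies; copy=True guarantees a fresh tensor
# even when frame already matches noise's device and dtype.
vf = frame.to(device=noise.device, dtype=noise.dtype, copy=True)
vf[:, :16] *= 0.5  # toy stand-in for process_latent_in; mutates only the copy

assert vf.dtype == noise.dtype
assert frame.dtype == torch.float64  # the original frame is untouched
```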