mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-22 20:40:49 +08:00
Merge branch 'master' into dr-support-pip-cm
This commit is contained in:
commit
02317a1f71
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
class CONDRegular:
|
class CONDRegular:
|
||||||
@ -10,12 +11,15 @@ class CONDRegular:
|
|||||||
def _copy_with(self, cond):
|
def _copy_with(self, cond):
|
||||||
return self.__class__(cond)
|
return self.__class__(cond)
|
||||||
|
|
||||||
def process_cond(self, batch_size, device, **kwargs):
|
def process_cond(self, batch_size, **kwargs):
|
||||||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
|
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size))
|
||||||
|
|
||||||
def can_concat(self, other):
|
def can_concat(self, other):
|
||||||
if self.cond.shape != other.cond.shape:
|
if self.cond.shape != other.cond.shape:
|
||||||
return False
|
return False
|
||||||
|
if self.cond.device != other.cond.device:
|
||||||
|
logging.warning("WARNING: conds not on same device, skipping concat.")
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def concat(self, others):
|
def concat(self, others):
|
||||||
@ -29,14 +33,14 @@ class CONDRegular:
|
|||||||
|
|
||||||
|
|
||||||
class CONDNoiseShape(CONDRegular):
|
class CONDNoiseShape(CONDRegular):
|
||||||
def process_cond(self, batch_size, device, area, **kwargs):
|
def process_cond(self, batch_size, area, **kwargs):
|
||||||
data = self.cond
|
data = self.cond
|
||||||
if area is not None:
|
if area is not None:
|
||||||
dims = len(area) // 2
|
dims = len(area) // 2
|
||||||
for i in range(dims):
|
for i in range(dims):
|
||||||
data = data.narrow(i + 2, area[i + dims], area[i])
|
data = data.narrow(i + 2, area[i + dims], area[i])
|
||||||
|
|
||||||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
|
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size))
|
||||||
|
|
||||||
|
|
||||||
class CONDCrossAttn(CONDRegular):
|
class CONDCrossAttn(CONDRegular):
|
||||||
@ -51,6 +55,9 @@ class CONDCrossAttn(CONDRegular):
|
|||||||
diff = mult_min // min(s1[1], s2[1])
|
diff = mult_min // min(s1[1], s2[1])
|
||||||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
|
||||||
return False
|
return False
|
||||||
|
if self.cond.device != other.cond.device:
|
||||||
|
logging.warning("WARNING: conds not on same device: skipping concat.")
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def concat(self, others):
|
def concat(self, others):
|
||||||
@ -73,7 +80,7 @@ class CONDConstant(CONDRegular):
|
|||||||
def __init__(self, cond):
|
def __init__(self, cond):
|
||||||
self.cond = cond
|
self.cond = cond
|
||||||
|
|
||||||
def process_cond(self, batch_size, device, **kwargs):
|
def process_cond(self, batch_size, **kwargs):
|
||||||
return self._copy_with(self.cond)
|
return self._copy_with(self.cond)
|
||||||
|
|
||||||
def can_concat(self, other):
|
def can_concat(self, other):
|
||||||
@ -92,10 +99,10 @@ class CONDList(CONDRegular):
|
|||||||
def __init__(self, cond):
|
def __init__(self, cond):
|
||||||
self.cond = cond
|
self.cond = cond
|
||||||
|
|
||||||
def process_cond(self, batch_size, device, **kwargs):
|
def process_cond(self, batch_size, **kwargs):
|
||||||
out = []
|
out = []
|
||||||
for c in self.cond:
|
for c in self.cond:
|
||||||
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
|
out.append(comfy.utils.repeat_to_batch_size(c, batch_size))
|
||||||
|
|
||||||
return self._copy_with(out)
|
return self._copy_with(out)
|
||||||
|
|
||||||
|
|||||||
@ -28,6 +28,7 @@ import comfy.model_detection
|
|||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
|
import comfy.model_base
|
||||||
|
|
||||||
import comfy.cldm.cldm
|
import comfy.cldm.cldm
|
||||||
import comfy.t2i_adapter.adapter
|
import comfy.t2i_adapter.adapter
|
||||||
@ -264,12 +265,12 @@ class ControlNet(ControlBase):
|
|||||||
for c in self.extra_conds:
|
for c in self.extra_conds:
|
||||||
temp = cond.get(c, None)
|
temp = cond.get(c, None)
|
||||||
if temp is not None:
|
if temp is not None:
|
||||||
extra[c] = temp.to(dtype)
|
extra[c] = comfy.model_base.convert_tensor(temp, dtype, x_noisy.device)
|
||||||
|
|
||||||
timestep = self.model_sampling_current.timestep(t)
|
timestep = self.model_sampling_current.timestep(t)
|
||||||
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
||||||
|
|
||||||
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
|
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=comfy.model_management.cast_to_device(context, x_noisy.device, dtype), **extra)
|
||||||
return self.control_merge(control, control_prev, output_dtype=None)
|
return self.control_merge(control, control_prev, output_dtype=None)
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
|
|||||||
@ -109,9 +109,9 @@ def model_sampling(model_config, model_type):
|
|||||||
def convert_tensor(extra, dtype, device):
|
def convert_tensor(extra, dtype, device):
|
||||||
if hasattr(extra, "dtype"):
|
if hasattr(extra, "dtype"):
|
||||||
if extra.dtype != torch.int and extra.dtype != torch.long:
|
if extra.dtype != torch.int and extra.dtype != torch.long:
|
||||||
extra = extra.to(dtype=dtype, device=device)
|
extra = comfy.model_management.cast_to_device(extra, device, dtype)
|
||||||
else:
|
else:
|
||||||
extra = extra.to(device=device)
|
extra = comfy.model_management.cast_to_device(extra, device, None)
|
||||||
return extra
|
return extra
|
||||||
|
|
||||||
|
|
||||||
@ -162,7 +162,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
xc = self.model_sampling.calculate_input(sigma, x)
|
xc = self.model_sampling.calculate_input(sigma, x)
|
||||||
|
|
||||||
if c_concat is not None:
|
if c_concat is not None:
|
||||||
xc = torch.cat([xc] + [c_concat], dim=1)
|
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
||||||
|
|
||||||
context = c_crossattn
|
context = c_crossattn
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype()
|
||||||
@ -174,7 +174,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
device = xc.device
|
device = xc.device
|
||||||
t = self.model_sampling.timestep(t).float()
|
t = self.model_sampling.timestep(t).float()
|
||||||
if context is not None:
|
if context is not None:
|
||||||
context = context.to(dtype=dtype, device=device)
|
context = comfy.model_management.cast_to_device(context, device, dtype)
|
||||||
|
|
||||||
extra_conds = {}
|
extra_conds = {}
|
||||||
for o in kwargs:
|
for o in kwargs:
|
||||||
@ -401,7 +401,7 @@ class SD21UNCLIP(BaseModel):
|
|||||||
unclip_conditioning = kwargs.get("unclip_conditioning", None)
|
unclip_conditioning = kwargs.get("unclip_conditioning", None)
|
||||||
device = kwargs["device"]
|
device = kwargs["device"]
|
||||||
if unclip_conditioning is None:
|
if unclip_conditioning is None:
|
||||||
return torch.zeros((1, self.adm_channels))
|
return torch.zeros((1, self.adm_channels), device=device)
|
||||||
else:
|
else:
|
||||||
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
|
return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)
|
||||||
|
|
||||||
@ -615,9 +615,11 @@ class IP2P:
|
|||||||
|
|
||||||
if image is None:
|
if image is None:
|
||||||
image = torch.zeros_like(noise)
|
image = torch.zeros_like(noise)
|
||||||
|
else:
|
||||||
|
image = image.to(device=device)
|
||||||
|
|
||||||
if image.shape[1:] != noise.shape[1:]:
|
if image.shape[1:] != noise.shape[1:]:
|
||||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||||
|
|
||||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||||
return self.process_ip2p_image_in(image)
|
return self.process_ip2p_image_in(image)
|
||||||
@ -696,7 +698,7 @@ class StableCascade_B(BaseModel):
|
|||||||
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
|
#size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
|
||||||
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
|
prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
|
||||||
|
|
||||||
out["effnet"] = comfy.conds.CONDRegular(prior)
|
out["effnet"] = comfy.conds.CONDRegular(prior.to(device=noise.device))
|
||||||
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -1161,10 +1163,10 @@ class WAN21_Vace(WAN21):
|
|||||||
|
|
||||||
vace_frames_out = []
|
vace_frames_out = []
|
||||||
for j in range(len(vace_frames)):
|
for j in range(len(vace_frames)):
|
||||||
vf = vace_frames[j].clone()
|
vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True)
|
||||||
for i in range(0, vf.shape[1], 16):
|
for i in range(0, vf.shape[1], 16):
|
||||||
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
|
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
|
||||||
vf = torch.cat([vf, mask[j]], dim=1)
|
vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1)
|
||||||
vace_frames_out.append(vf)
|
vace_frames_out.append(vf)
|
||||||
|
|
||||||
vace_frames = torch.stack(vace_frames_out, dim=1)
|
vace_frames = torch.stack(vace_frames_out, dim=1)
|
||||||
|
|||||||
@ -89,7 +89,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
|
|||||||
conditioning = {}
|
conditioning = {}
|
||||||
model_conds = conds["model_conds"]
|
model_conds = conds["model_conds"]
|
||||||
for c in model_conds:
|
for c in model_conds:
|
||||||
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area)
|
conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], area=area)
|
||||||
|
|
||||||
hooks = conds.get('hooks', None)
|
hooks = conds.get('hooks', None)
|
||||||
control = conds.get('control', None)
|
control = conds.get('control', None)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user