Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-26 06:10:15 +08:00)

Merge branch 'comfyanonymous:master' into master

Commit fc93a6f534
@@ -853,6 +853,11 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
     return x


+@torch.no_grad()
+def sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
+    return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
+
 @torch.no_grad()
 def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
     """DPM-Solver++(3M) SDE."""
@@ -925,6 +930,16 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
     return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)


+@torch.no_grad()
+def sample_dpmpp_2m_sde_heun_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='heun'):
+    if len(sigmas) <= 1:
+        return x
+    extra_args = {} if extra_args is None else extra_args
+    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
+    return sample_dpmpp_2m_sde_heun(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
+
+
 @torch.no_grad()
 def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
     if len(sigmas) <= 1:
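Note: the new heun entry points above are thin wrappers around the existing 2M SDE sampler; the _gpu variant differs only in building its noise sampler with cpu=False. A minimal sketch of that difference, assuming the class is importable from ComfyUI's vendored k-diffusion module (comfy.k_diffusion.sampling, an assumption, not stated in this diff):

    import torch
    from comfy.k_diffusion.sampling import BrownianTreeNoiseSampler

    x = torch.randn(1, 4, 8, 8)                       # latent being denoised
    sigmas = torch.tensor([14.6, 7.0, 3.0, 1.0, 0.0])
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    # cpu=False keeps the Brownian tree on the sampling device instead of the CPU;
    # this is the only thing the _gpu wrapper changes before delegating.
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=42, cpu=False)
    noise = noise_sampler(sigmas[0], sigmas[1])       # correlated noise between two sigmas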
@@ -158,7 +158,7 @@ class Flux(nn.Module):
                 if i < len(control_i):
                     add = control_i[i]
                     if add is not None:
-                        img += add
+                        img[:, :add.shape[1]] += add

             if img.dtype == torch.float16:
                 img = torch.nan_to_num(img, nan=0.0, posinf=65504, neginf=-65504)
@@ -189,7 +189,7 @@ class Flux(nn.Module):
                 if i < len(control_o):
                     add = control_o[i]
                     if add is not None:
-                        img[:, txt.shape[1] :, ...] += add
+                        img[:, txt.shape[1] : txt.shape[1] + add.shape[1], ...] += add

         img = img[:, txt.shape[1] :, ...]

@@ -459,7 +459,7 @@ class QwenImageTransformer2DModel(nn.Module):
                 if i < len(control_i):
                     add = control_i[i]
                     if add is not None:
-                        hidden_states += add
+                        hidden_states[:, :add.shape[1]] += add

         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
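Note: the three control-residual hunks above share one fix: the residual may cover fewer tokens than the hidden sequence (for example when reference tokens are appended to it), so the add is restricted to the overlapping prefix instead of the whole sequence. A minimal self-contained sketch of the pattern:

    import torch

    img = torch.zeros(1, 8, 4)    # (batch, tokens, dim) hidden sequence
    add = torch.ones(1, 6, 4)     # control residual covering only the first 6 tokens
    img[:, :add.shape[1]] += add  # in-place add on the overlapping prefix only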
@@ -1255,6 +1255,7 @@ class WanModel_S2V(WanModel):
         audio_emb = None

         # embeddings
+        bs, _, time, height, width = x.shape
         x = self.patch_embedding(x.float()).to(x.dtype)
         if control_video is not None:
             x = x + self.cond_encoder(control_video)
@@ -1272,11 +1273,12 @@ class WanModel_S2V(WanModel):
         if reference_latent is not None:
             ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
             ref = ref.flatten(2).transpose(1, 2)
-            freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=30, device=x.device, dtype=x.dtype)
+            freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=max(30, time + 9), device=x.device, dtype=x.dtype)
             ref = ref + cond_mask_weight[1]
             x = torch.cat([x, ref], dim=1)
             freqs = torch.cat([freqs, freqs_ref], dim=1)
             t = torch.cat([t, torch.zeros((t.shape[0], reference_latent.shape[-3]), device=t.device, dtype=t.dtype)], dim=1)
+            del ref, freqs_ref

         if reference_motion is not None:
             motion_encoded, freqs_motion = self.frame_packer(reference_motion, self)
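Note: the t_start change keeps the reference tokens' RoPE time positions clear of the video's own frames. Using the time value unpacked in the earlier hunk (bs, _, time, height, width = x.shape), the reference now starts at least 9 latent frames past the end of the video rather than at a fixed offset of 30, so long videos no longer collide with it. A trivial sketch:

    time = 40                    # latent frames in the current video
    t_start = max(30, time + 9)  # 49: past the video's range; short videos still get 30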
@@ -1286,6 +1288,7 @@ class WanModel_S2V(WanModel):

             t = torch.repeat_interleave(t, 2, dim=1)
             t = torch.cat([t, torch.zeros((t.shape[0], 3), device=t.device, dtype=t.dtype)], dim=1)
+            del motion_encoded, freqs_motion

         # time embeddings
         e = self.time_embedding(
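Note: the per-frame timestep vector t is widened here to match the packed motion tokens; the added del just releases the large intermediates earlier. A runnable sketch of the t manipulation:

    import torch

    t = torch.tensor([[5., 5., 5.]])          # one timestep per latent frame
    t = torch.repeat_interleave(t, 2, dim=1)  # tensor([[5., 5., 5., 5., 5., 5.]])
    t = torch.cat([t, torch.zeros((t.shape[0], 3), dtype=t.dtype)], dim=1)
    print(t)                                  # tensor([[5., 5., 5., 5., 5., 5., 0., 0., 0.]])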
@@ -1296,7 +1299,6 @@ class WanModel_S2V(WanModel):
         # context
         context = self.text_embedding(context)

-
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.blocks):
@@ -150,6 +150,7 @@ class BaseModel(torch.nn.Module):
         logging.debug("adm {}".format(self.adm_channels))
         self.memory_usage_factor = model_config.memory_usage_factor
         self.memory_usage_factor_conds = ()
+        self.memory_usage_shape_process = {}

     def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -350,8 +351,15 @@ class BaseModel(torch.nn.Module):
         input_shapes = [input_shape]
         for c in self.memory_usage_factor_conds:
             shape = cond_shapes.get(c, None)
-            if shape is not None and len(shape) > 0:
-                input_shapes += shape
+            if shape is not None:
+                if c in self.memory_usage_shape_process:
+                    out = []
+                    for s in shape:
+                        out.append(self.memory_usage_shape_process[c](s))
+                    shape = out
+
+                if len(shape) > 0:
+                    input_shapes += shape

         if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
             dtype = self.get_dtype()
@@ -1204,6 +1212,8 @@ class WAN21_Camera(WAN21):
 class WAN22_S2V(WAN21):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
+        self.memory_usage_factor_conds = ("reference_latent", "reference_motion")
+        self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}

     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
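Note: together with the BaseModel hunk above, memory_usage_shape_process lets a model rescale a cond's shape before it feeds the memory estimate; WAN22_S2V uses it to count reference_motion at 1.5 effective frames. A sketch using the exact lambda from this diff on a hypothetical latent shape:

    process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
    shape = [1, 16, 73, 64, 64]                # hypothetical (B, C, T, H, W) reference_motion shape
    print(process["reference_motion"](shape))  # [1, 16, 1.5, 64, 64]: frame count scaled down for the estimate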
@@ -1224,6 +1234,17 @@ class WAN22_S2V(WAN21):
             out['control_video'] = comfy.conds.CONDRegular(self.process_latent_in(control_video))
         return out

+    def extra_conds_shapes(self, **kwargs):
+        out = {}
+        ref_latents = kwargs.get("reference_latents", None)
+        if ref_latents is not None:
+            out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
+
+        reference_motion = kwargs.get("reference_motion", None)
+        if reference_motion is not None:
+            out['reference_motion'] = reference_motion.shape
+        return out
+
 class WAN22(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
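Note: extra_conds_shapes collapses the reference latents into a flat (B, C, tokens) proxy shape for the memory estimator. A runnable sketch of that arithmetic with a hypothetical latent:

    import math
    import torch

    ref_latents = [torch.zeros(1, 16, 9, 64, 64)]  # hypothetical reference latent list
    n = sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16
    print([1, 16, n])  # [1, 16, 36864]: total elements divided across 16 channels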
comfy/samplers.py: 2 changes (Normal file → Executable file)
@@ -729,7 +729,7 @@ class Sampler:

 KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
-                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
+                  "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                   "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
                   "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]

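Note: a minimal usage sketch, assuming ComfyUI's name-based lookup (comfy.samplers.ksampler, not shown in this diff) resolves the new entries like any other KSAMPLER_NAMES member:

    import comfy.samplers

    # The two new names map to the heun-solver wrappers added in the sampling hunks.
    sampler = comfy.samplers.ksampler("dpmpp_2m_sde_heun")
    sampler_gpu = comfy.samplers.ksampler("dpmpp_2m_sde_heun_gpu")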
@@ -97,6 +97,9 @@ class LoKrAdapter(WeightAdapterBase):
             (mat1, mat2, alpha, None, None, None, None, None, None)
         )

+    def to_train(self):
+        return LokrDiff(self.weights)
+
     @classmethod
     def load(
         cls,
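Note: background, not code from this diff. LoKr parameterizes the trainable weight delta as a Kronecker product of two small factors, which is what the LokrDiff wrapper returned by to_train optimizes. A minimal sketch of the factorization:

    import torch

    mat1 = torch.randn(4, 4, requires_grad=True)
    mat2 = torch.randn(8, 8, requires_grad=True)
    delta_w = torch.kron(mat1, mat2)  # (32, 32) weight delta from two small trainable factors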
@@ -89,6 +89,7 @@ class DiffSynthCnetPatch:
         self.strength = strength
         self.mask = mask
         self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
+        self.encoded_image_size = (image.shape[1], image.shape[2])

     def encode_latent_cond(self, image):
         latent_image = self.vae.encode(image)
@@ -106,14 +107,15 @@ class DiffSynthCnetPatch:
         x = kwargs.get("x")
         img = kwargs.get("img")
         block_index = kwargs.get("block_index")
-        if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]:
-            spacial_compression = self.vae.spacial_compression_encode()
+        spacial_compression = self.vae.spacial_compression_encode()
+        if self.encoded_image is None or self.encoded_image_size != (x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression):
             image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
             loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
             self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
+            self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
             comfy.model_management.load_models_gpu(loaded_models)

-        img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
+        img[:, :self.encoded_image.shape[1]] += (self.model_patch.model.control_block(img[:, :self.encoded_image.shape[1]], self.encoded_image.to(img.dtype), block_index) * self.strength)
         kwargs['img'] = img
         return kwargs
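Note: the cnet patch now caches by pixel size rather than latent shape. The expected size is recovered from the denoised latent x and the VAE's spatial compression factor, and the control image is only re-encoded on a mismatch. A sketch of the check, with an assumed compression factor of 8 and hypothetical shapes:

    spacial_compression = 8                       # typical image-VAE factor (assumed)
    x_shape = (1, 16, 64, 48)                     # hypothetical latent (B, C, H, W)
    target = (x_shape[-1] * spacial_compression,  # width in pixels
              x_shape[-2] * spacial_compression)  # height in pixels
    encoded_image_size = (384, 512)
    needs_reencode = encoded_image_size != target
    print(target, needs_reencode)                 # (384, 512) False: cache is reused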
@@ -920,7 +920,7 @@ class WanSoundImageToVideo(io.ComfyNode):
             audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)

             positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_embed_bucket})
-            negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket})
+            negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_embed_bucket * 0.0})

         if ref_image is not None:
             ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
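Note: the negative branch now carries a zeroed audio embedding of the same shape, so classifier-free guidance contrasts "with audio" against "silent" instead of feeding both branches the same signal. A trivial sketch:

    import torch

    audio_embed_bucket = torch.randn(1, 8, 4, 32)  # hypothetical audio embedding
    negative_embed = audio_embed_bucket * 0.0      # same shape, silent content for the negative cond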