mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
Update model_base.py
This commit is contained in:
parent
8e04804d5e
commit
bd27c7201a
@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import comfy.ldm.lightricks.av_model
|
||||
import comfy.context_windows
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
from comfy.ldm.cascade.stage_b import StageB
|
||||
@ -51,6 +52,7 @@ import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -76,6 +78,7 @@ class ModelType(Enum):
|
||||
FLUX = 8
|
||||
IMG_TO_IMG = 9
|
||||
FLOW_COSMOS = 10
|
||||
IMG_TO_IMG_FLOW = 11
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@ -108,6 +111,8 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.FLOW_COSMOS:
|
||||
c = comfy.model_sampling.COSMOS_RFLOW
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -282,6 +287,12 @@ class BaseModel(torch.nn.Module):
|
||||
return data
|
||||
return None
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Hook for model-specific cond slicing when context windows are used.

    Override in subclasses to handle conds that need custom slicing.
    Use comfy.context_windows.slice_cond() for common cases.

    Returns:
        A sliced cond object, or None to fall through to default handling.
        This base implementation always returns None.
    """
    # NOTE: retain_index_list is never mutated here, so the shared default
    # list is harmless; the signature is kept as-is for existing callers.
    return None
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = {}
|
||||
concat_cond = self.concat_cond(**kwargs)
|
||||
@ -880,7 +891,7 @@ class Flux(BaseModel):
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def encode_adm(self, **kwargs):
    """Return the ``pooled_output`` keyword argument, or None when it is absent.

    Uses ``dict.get`` so a missing ``pooled_output`` yields None instead of
    raising ``KeyError``.
    """
    return kwargs.get("pooled_output", None)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@ -922,6 +933,26 @@ class Flux(BaseModel):
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
return out
|
||||
|
||||
class LongCatImage(Flux):
    """Flux subclass that fills in default rope shift options, disables ADM
    encoding and drops the ``guidance`` cond."""

    def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        """Set default ``rope_options`` shifts, then delegate to the parent.

        ``shift_t`` defaults to 1.0; ``shift_y`` and ``shift_x`` default to
        ``c_crossattn.shape[1]`` (or 512.0 when ``c_crossattn`` is None).
        Values already present in ``rope_options`` are left untouched.
        """
        # Copy both dicts so neither the caller's transformer_options nor its
        # rope_options (nor the shared default argument) is modified.
        transformer_options = transformer_options.copy()
        rope_opts = dict(transformer_options.get("rope_options", {}))

        pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
        rope_opts.setdefault("shift_t", 1.0)
        rope_opts.setdefault("shift_y", pe_len)
        rope_opts.setdefault("shift_x", pe_len)
        transformer_options["rope_options"] = rope_opts

        return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)

    def encode_adm(self, **kwargs):
        """Always return None (no ADM conditioning is produced)."""
        return None

    def extra_conds(self, **kwargs):
        """Return the parent's extra conds with any ``guidance`` entry removed."""
        out = super().extra_conds(**kwargs)
        out.pop('guidance', None)
        return out
|
||||
|
||||
class Flux2(Flux):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@ -971,6 +1002,10 @@ class LTXV(BaseModel):
|
||||
if keyframe_idxs is not None:
|
||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||
|
||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||
@ -988,10 +1023,14 @@ class LTXAV(BaseModel):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
@ -1019,6 +1058,14 @@ class LTXAV(BaseModel):
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
|
||||
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||
if guide_attention_entries is not None:
|
||||
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||
|
||||
ref_audio = kwargs.get("ref_audio", None)
|
||||
if ref_audio is not None:
|
||||
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
|
||||
|
||||
return out
|
||||
|
||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||
@ -1229,6 +1276,11 @@ class Lumina2(BaseModel):
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
return out
|
||||
|
||||
class ZImagePixelSpace(Lumina2):
    """Lumina2 subclass built on ``comfy.ldm.lumina.model.NextDiTPixelSpace``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        # BaseModel.__init__ is called directly, so Lumina2.__init__ is not run;
        # presumably this is what lets a different unet_model be passed here.
        BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
        self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
@ -1336,6 +1388,11 @@ class WAN21_Vace(WAN21):
|
||||
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice ``vace_context`` for a context window; defer other conds to the parent.

    ``vace_context`` is sliced with ``comfy.context_windows.slice_cond`` using
    ``temporal_dim=3`` and the given ``retain_index_list``. Any other key is
    passed to the parent implementation.
    """
    if cond_key == "vace_context":
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
    return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN21_Camera(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
|
||||
@ -1388,6 +1445,11 @@ class WAN21_HuMo(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice ``audio_embed`` for a context window; defer other conds to the parent.

    ``audio_embed`` is sliced with ``comfy.context_windows.slice_cond`` using
    ``temporal_dim=1``. Any other key is passed to the parent implementation
    together with ``retain_index_list``.
    """
    if cond_key == "audio_embed":
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
    return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22_Animate(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
|
||||
@ -1405,6 +1467,13 @@ class WAN22_Animate(WAN21):
|
||||
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice ``face_pixel_values`` and ``pose_latents`` for a context window.

    Both conds are sliced with ``comfy.context_windows.slice_cond`` along
    ``temporal_dim=2`` with ``temporal_offset=1``; ``face_pixel_values``
    additionally uses ``temporal_scale=4``. Any other key is passed to the
    parent implementation together with ``retain_index_list``.
    """
    if cond_key == "face_pixel_values":
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
    if cond_key == "pose_latents":
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
    return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22_S2V(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||
@ -1441,6 +1510,11 @@ class WAN22_S2V(WAN21):
|
||||
out['reference_motion'] = reference_motion.shape
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice ``audio_embed`` for a context window; defer other conds to the parent.

    ``audio_embed`` is sliced with ``comfy.context_windows.slice_cond`` using
    ``temporal_dim=1``. Any other key is passed to the parent implementation
    together with ``retain_index_list``.
    """
    if cond_key == "audio_embed":
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
    return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
class WAN22(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
@ -1462,6 +1536,50 @@ class WAN22(WAN21):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
    """Return ``latent_image`` unchanged; ``sigma``, ``noise`` and extra kwargs are ignored."""
    return latent_image
|
||||
|
||||
class WAN21_FlowRVS(WAN21):
    """WAN21 subclass that defaults to ``ModelType.IMG_TO_IMG_FLOW`` and uses
    ``comfy.ldm.wan.model.WanModel``."""

    def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
        # NOTE: this writes into the caller's model_config.unet_config in place.
        model_config.unet_config["model_type"] = "t2v"
        # super(WAN21, self) skips WAN21.__init__ and runs the next __init__ in the MRO.
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
        self.image_to_video = image_to_video
|
||||
|
||||
class WAN21_SCAIL(WAN21):
    """WAN21 subclass backed by ``comfy.ldm.wan.model.SCAILWanModel``.

    Adds ``reference_latent`` and ``pose_latents`` conds on top of the
    parent's extra conds.
    """

    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
        # super(WAN21, self) skips WAN21.__init__ and runs the next __init__ in the MRO.
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
        self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
        # The third entry of a pose_latents shape is replaced by the constant 1.5.
        self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
        self.image_to_video = image_to_video

    def extra_conds(self, **kwargs):
        """Extend the parent's conds with ``reference_latent`` / ``pose_latents``.

        Each latent is passed through ``process_latent_in`` and then
        concatenated along dim 1 with a ones mask shaped like its first
        four channels.
        """
        out = super().extra_conds(**kwargs)

        reference_latents = kwargs.get("reference_latents", None)
        if reference_latents is not None:
            # Only the last entry of reference_latents is used.
            ref_latent = self.process_latent_in(reference_latents[-1])
            ref_mask = torch.ones_like(ref_latent[:, :4])
            ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
            out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)

        pose_latents = kwargs.get("pose_video_latent", None)
        if pose_latents is not None:
            pose_latents = self.process_latent_in(pose_latents)
            pose_mask = torch.ones_like(pose_latents[:, :4])
            pose_latents = torch.cat([pose_latents, pose_mask], dim=1)
            out['pose_latents'] = comfy.conds.CONDRegular(pose_latents)

        return out

    def extra_conds_shapes(self, **kwargs):
        """Return shape lists for ``reference_latent`` / ``pose_latents`` when present."""
        out = {}

        ref_latents = kwargs.get("reference_latents", None)
        if ref_latents is not None:
            # Total element count over all reference latents, integer-divided by 16.
            total_elements = sum(math.prod(a.size()) for a in ref_latents)
            out['reference_latent'] = [1, 20, total_elements // 16]

        pose_latents = kwargs.get("pose_video_latent", None)
        if pose_latents is not None:
            # Same shape as the input except that dim 1 is reported as 20.
            out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]

        return out
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||
@ -1547,6 +1665,24 @@ class ACEStep(BaseModel):
|
||||
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
|
||||
return out
|
||||
|
||||
def _ace_sampling_progress(timestep, sample_sigmas):
    """Return progress through ``sample_sigmas`` for ``timestep``, or None if it cannot be computed.

    Progress is ``1 - (timestep.max() - last) / (first - last)``. None is
    returned when ``sample_sigmas`` is None or ``first - last`` is not > 0.
    """
    if sample_sigmas is None:
        return None
    current_sigma = float(timestep.max())
    max_sigma = float(sample_sigmas[0])
    min_sigma = float(sample_sigmas[-1])
    sigma_range = max_sigma - min_sigma
    # `not (... > 0)` rather than `<= 0` so a NaN range is also rejected.
    if not sigma_range > 0:
        return None
    return 1.0 - (current_sigma - min_sigma) / sigma_range


def ace_audio_cover_strength_wrapper(strength, executor, x, timestep, model_options={}, seed=None):
    """Switch off ``is_covers`` once sampling progress reaches ``strength``, then run ``executor``.

    Progress is derived from ``model_options['transformer_options']['sample_sigmas']``
    and ``timestep.max()``. When progress >= ``strength``, every cond in
    ``model_options['conds']`` that has ``model_conds['is_covers']`` gets it
    replaced, in place, with ``comfy.conds.CONDConstant(False)``.

    Returns:
        Whatever ``executor(x, timestep, model_options, seed)`` returns; the
        executor is always called.
    """
    sample_sigmas = model_options.get('transformer_options', {}).get('sample_sigmas', None)
    progress = _ace_sampling_progress(timestep, sample_sigmas)

    if progress is not None and progress >= strength:
        conds = model_options.get('conds', None)
        if conds is not None:
            for cond_list in conds.values():
                for cond in cond_list:
                    if 'model_conds' in cond and 'is_covers' in cond['model_conds']:
                        cond['model_conds']['is_covers'] = comfy.conds.CONDConstant(False)

    return executor(x, timestep, model_options, seed)
|
||||
|
||||
class ACEStep15(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel)
|
||||
@ -1589,32 +1725,17 @@ class ACEStep15(BaseModel):
|
||||
|
||||
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
|
||||
|
||||
audio_cover_strength = kwargs.get('audio_cover_strength', 1.0)
|
||||
audio_cover_strength = max(0.0, min(1.0, float(kwargs.get('audio_cover_strength', 1.0))))
|
||||
is_cover_mode = out.get('is_covers', comfy.conds.CONDConstant(None)).cond != False
|
||||
|
||||
if audio_cover_strength < 1.0 and is_cover_mode and self.current_patcher is not None:
|
||||
if not self.current_patcher.get_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, 'ace_step_cover_strength'):
|
||||
_strength = audio_cover_strength
|
||||
def audio_cover_strength_wrapper(executor, x, timestep, model_options={}, seed=None):
|
||||
sample_sigmas = model_options.get('transformer_options', {}).get('sample_sigmas', None)
|
||||
if sample_sigmas is not None:
|
||||
current_sigma = float(timestep.max())
|
||||
max_sigma = float(sample_sigmas[0])
|
||||
min_sigma = float(sample_sigmas[-1])
|
||||
sigma_range = max_sigma - min_sigma
|
||||
if sigma_range > 0:
|
||||
progress = 1.0 - (current_sigma - min_sigma) / sigma_range
|
||||
if progress >= _strength:
|
||||
conds = model_options.get('conds', None)
|
||||
if conds is not None:
|
||||
for cond_list in conds.values():
|
||||
for cond in cond_list:
|
||||
if 'model_conds' in cond and 'is_covers' in cond['model_conds']:
|
||||
cond['model_conds']['is_covers'] = comfy.conds.CONDConstant(False)
|
||||
return executor(x, timestep, model_options, seed)
|
||||
import functools
|
||||
bound_wrapper = functools.partial(ace_audio_cover_strength_wrapper, audio_cover_strength)
|
||||
self.current_patcher.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.PREDICT_NOISE,
|
||||
'ace_step_cover_strength',
|
||||
audio_cover_strength_wrapper
|
||||
bound_wrapper
|
||||
)
|
||||
|
||||
return out
|
||||
@ -1869,3 +1990,7 @@ class Kandinsky5Image(Kandinsky5):
|
||||
|
||||
def concat_cond(self, **kwargs):
    """Always return None; all keyword arguments are ignored."""
    return None
|
||||
|
||||
class RT_DETR_v4(BaseModel):
    """BaseModel subclass whose unet_model is ``comfy.ldm.rt_detr.rtdetr_v4.RTv4``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user