Update model_base.py

This commit is contained in:
Subarasheese 2026-04-09 20:58:45 -03:00 committed by GitHub
parent 8e04804d5e
commit bd27c7201a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch import torch
import logging import logging
import comfy.ldm.lightricks.av_model import comfy.ldm.lightricks.av_model
import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB from comfy.ldm.cascade.stage_b import StageB
@ -51,6 +52,7 @@ import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15 import comfy.ldm.ace.ace_step15
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.model_management import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
@ -76,6 +78,7 @@ class ModelType(Enum):
FLUX = 8 FLUX = 8
IMG_TO_IMG = 9 IMG_TO_IMG = 9
FLOW_COSMOS = 10 FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
@ -108,6 +111,8 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLOW_COSMOS: elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSampling(s, c): class ModelSampling(s, c):
pass pass
@ -282,6 +287,12 @@ class BaseModel(torch.nn.Module):
return data return data
return None return None
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=None):
    """Hook for model-specific cond slicing when sampling with context windows.

    Subclasses override this to slice conds that carry a temporal axis
    (audio embeddings, pose latents, ...) down to the frames covered by
    ``window``.  Returning ``None`` tells the caller to fall through to
    default handling.  Use ``comfy.context_windows.slice_cond()`` for the
    common cases.

    Args:
        cond_key: name of the cond being resized.
        cond_value: the cond object to slice.
        window: the context window describing the frame range.
        x_in: the model input tensor the window indexes into.
        device: target device for any sliced tensors.
        retain_index_list: indices to always keep; defaults to an empty list.

    Returns:
        A sliced cond object, or None for default handling.
    """
    # Default changed from a shared mutable `[]` to None (the value is unused
    # in this base implementation); subclasses normalize it themselves.
    return None
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = {} out = {}
concat_cond = self.concat_cond(**kwargs) concat_cond = self.concat_cond(**kwargs)
@ -880,7 +891,7 @@ class Flux(BaseModel):
return torch.cat((image, mask), dim=1) return torch.cat((image, mask), dim=1)
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
return kwargs["pooled_output"] return kwargs.get("pooled_output", None)
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
@ -922,6 +933,26 @@ class Flux(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out return out
class LongCatImage(Flux):
    """Flux variant that applies LongCat rope shifting and drops pooled/guidance conds."""

    def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
        # Work on copies so caller-owned dicts are never mutated.
        transformer_options = transformer_options.copy()
        opts = dict(transformer_options.get("rope_options", {}))
        # Default rope shift equals the text sequence length (512 when no text cond).
        seq_len = 512.0 if c_crossattn is None else float(c_crossattn.shape[1])
        for key, default in (("shift_t", 1.0), ("shift_y", seq_len), ("shift_x", seq_len)):
            opts.setdefault(key, default)
        transformer_options["rope_options"] = opts
        return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)

    def encode_adm(self, **kwargs):
        # This model takes no ADM/pooled conditioning.
        return None

    def extra_conds(self, **kwargs):
        conds = super().extra_conds(**kwargs)
        # Flux adds a 'guidance' cond; LongCat does not use it.
        conds.pop('guidance', None)
        return conds
class Flux2(Flux): class Flux2(Flux):
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
@ -971,6 +1002,10 @@ class LTXV(BaseModel):
if keyframe_idxs is not None: if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@ -988,10 +1023,14 @@ class LTXAV(BaseModel):
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None) attention_mask = kwargs.get("attention_mask", None)
device = kwargs["device"]
if attention_mask is not None: if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None) cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None: if cross_attn is not None:
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
@ -1019,6 +1058,14 @@ class LTXAV(BaseModel):
if latent_shapes is not None: if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
ref_audio = kwargs.get("ref_audio", None)
if ref_audio is not None:
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@ -1229,6 +1276,11 @@ class Lumina2(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out return out
class ZImagePixelSpace(Lumina2):
    """Lumina2-based pixel-space model using the NextDiTPixelSpace backbone."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        # Call BaseModel.__init__ directly (bypassing Lumina2.__init__) so the
        # pixel-space transformer is used instead of Lumina2's default model.
        BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
        # Conds that contribute to memory-usage estimation.
        self.memory_usage_factor_conds = ("ref_latents",)
class WAN21(BaseModel): class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
@ -1336,6 +1388,11 @@ class WAN21_Vace(WAN21):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=None):
    """Slice the VACE context cond along its temporal dimension for a context window."""
    # Avoid the shared mutable-default pitfall (`retain_index_list=[]`).
    if retain_index_list is None:
        retain_index_list = []
    if cond_key == "vace_context":
        # vace_context carries its temporal axis on dim 3.
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
    return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN21_Camera(WAN21): class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
@ -1388,6 +1445,11 @@ class WAN21_HuMo(WAN21):
return out return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice the HuMo audio embedding cond down to the current context window."""
    if cond_key == "audio_embed":
        # audio_embed is temporal along dim 1. NOTE(review): retain_index_list is
        # intentionally not forwarded to slice_cond here (matches original behavior)
        # — confirm audio frames never need to be pinned.
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
    return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22_Animate(WAN21): class WAN22_Animate(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
@ -1405,6 +1467,13 @@ class WAN22_Animate(WAN21):
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
return out return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=None):
    """Slice animate-specific conds (face frames, pose latents) for a context window."""
    # Avoid the shared mutable-default pitfall (`retain_index_list=[]`).
    if retain_index_list is None:
        retain_index_list = []
    if cond_key == "face_pixel_values":
        # Pixel-space face frames: temporal on dim 2, 4 pixel frames per latent
        # frame (temporal_scale=4), shifted by one frame (temporal_offset=1).
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
    if cond_key == "pose_latents":
        # Pose latents share the latent frame rate but are offset by one frame.
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
    return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22_S2V(WAN21): class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
@ -1441,6 +1510,11 @@ class WAN22_S2V(WAN21):
out['reference_motion'] = reference_motion.shape out['reference_motion'] = reference_motion.shape
return out return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice the S2V audio embedding cond down to the current context window."""
    if cond_key == "audio_embed":
        # audio_embed is temporal along dim 1. NOTE(review): retain_index_list is
        # intentionally not forwarded here (matches original behavior).
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
    return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22(WAN21): class WAN22(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
@ -1462,6 +1536,50 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image return latent_image
class WAN21_FlowRVS(WAN21):
    """WAN 2.1 FlowRVS variant sampled with the IMG_TO_IMG_FLOW model type."""

    def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
        # NOTE(review): mutates the shared model_config in place — forces the
        # backbone into t2v mode before construction; confirm no other model
        # instance reuses this config afterwards.
        model_config.unet_config["model_type"] = "t2v"
        # Skip WAN21.__init__ and initialize BaseModel directly with WanModel.
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
        self.image_to_video = image_to_video
class WAN21_SCAIL(WAN21):
    """WAN 2.1 SCAIL variant: pose-driven model conditioned on reference and pose latents."""

    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
        # Skip WAN21.__init__ and initialize BaseModel directly with the SCAIL backbone.
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
        # Conds that factor into memory-usage estimation, and how to reshape
        # pose_latents for that estimate.
        self.memory_usage_factor_conds = ("reference_latent", "pose_latents")
        self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]}
        self.image_to_video = image_to_video

    def _latent_with_mask(self, latent):
        # Append a 4-channel all-ones mask alongside the latent channels.
        mask = torch.ones_like(latent[:, :4])
        return torch.cat([latent, mask], dim=1)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        refs = kwargs.get("reference_latents", None)
        if refs is not None:
            # Only the most recent reference latent is used.
            ref = self.process_latent_in(refs[-1])
            out['reference_latent'] = comfy.conds.CONDRegular(self._latent_with_mask(ref))
        pose = kwargs.get("pose_video_latent", None)
        if pose is not None:
            out['pose_latents'] = comfy.conds.CONDRegular(self._latent_with_mask(self.process_latent_in(pose)))
        return out

    def extra_conds_shapes(self, **kwargs):
        shapes = {}
        refs = kwargs.get("reference_latents", None)
        if refs is not None:
            shapes['reference_latent'] = [1, 20, sum(math.prod(r.size()) for r in refs) // 16]
        pose = kwargs.get("pose_video_latent", None)
        if pose is not None:
            shapes['pose_latents'] = [pose.shape[0], 20, *pose.shape[2:]]
        return shapes
class Hunyuan3Dv2(BaseModel): class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
@ -1547,6 +1665,24 @@ class ACEStep(BaseModel):
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out return out
def ace_audio_cover_strength_wrapper(strength, executor, x, timestep, model_options=None, seed=None):
    """PREDICT_NOISE wrapper that disables audio 'cover' conditioning partway through sampling.

    Sampling progress is derived from the sigma schedule (0.0 at the first
    sigma, 1.0 at the last).  Once progress reaches ``strength``, every cond
    carrying an 'is_covers' model cond is switched to CONDConstant(False) so
    the remaining steps run without cover conditioning.  Bind ``strength``
    with functools.partial before registering as a wrapper.

    Args:
        strength: progress threshold in [0, 1] at which covers are disabled.
        executor: the wrapped predict-noise callable.
        x, timestep, model_options, seed: forwarded to the executor.

    Returns:
        The executor's result.
    """
    # Avoid the shared mutable-default pitfall (`model_options={}`).
    if model_options is None:
        model_options = {}
    sample_sigmas = model_options.get('transformer_options', {}).get('sample_sigmas', None)
    if sample_sigmas is not None:
        current_sigma = float(timestep.max())
        max_sigma = float(sample_sigmas[0])
        min_sigma = float(sample_sigmas[-1])
        sigma_range = max_sigma - min_sigma
        if sigma_range > 0:
            # Sigma decreases over sampling; invert so progress runs 0 -> 1.
            progress = 1.0 - (current_sigma - min_sigma) / sigma_range
            if progress >= strength:
                conds = model_options.get('conds', None)
                if conds is not None:
                    for cond_list in conds.values():
                        for cond in cond_list:
                            if 'model_conds' in cond and 'is_covers' in cond['model_conds']:
                                cond['model_conds']['is_covers'] = comfy.conds.CONDConstant(False)
    return executor(x, timestep, model_options, seed)
class ACEStep15(BaseModel): class ACEStep15(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel)
@ -1589,32 +1725,17 @@ class ACEStep15(BaseModel):
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
audio_cover_strength = kwargs.get('audio_cover_strength', 1.0) audio_cover_strength = max(0.0, min(1.0, float(kwargs.get('audio_cover_strength', 1.0))))
is_cover_mode = out.get('is_covers', comfy.conds.CONDConstant(None)).cond != False is_cover_mode = out.get('is_covers', comfy.conds.CONDConstant(None)).cond != False
if audio_cover_strength < 1.0 and is_cover_mode and self.current_patcher is not None: if audio_cover_strength < 1.0 and is_cover_mode and self.current_patcher is not None:
if not self.current_patcher.get_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, 'ace_step_cover_strength'): if not self.current_patcher.get_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, 'ace_step_cover_strength'):
_strength = audio_cover_strength import functools
def audio_cover_strength_wrapper(executor, x, timestep, model_options={}, seed=None): bound_wrapper = functools.partial(ace_audio_cover_strength_wrapper, audio_cover_strength)
sample_sigmas = model_options.get('transformer_options', {}).get('sample_sigmas', None)
if sample_sigmas is not None:
current_sigma = float(timestep.max())
max_sigma = float(sample_sigmas[0])
min_sigma = float(sample_sigmas[-1])
sigma_range = max_sigma - min_sigma
if sigma_range > 0:
progress = 1.0 - (current_sigma - min_sigma) / sigma_range
if progress >= _strength:
conds = model_options.get('conds', None)
if conds is not None:
for cond_list in conds.values():
for cond in cond_list:
if 'model_conds' in cond and 'is_covers' in cond['model_conds']:
cond['model_conds']['is_covers'] = comfy.conds.CONDConstant(False)
return executor(x, timestep, model_options, seed)
self.current_patcher.add_wrapper_with_key( self.current_patcher.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.PREDICT_NOISE, comfy.patcher_extension.WrappersMP.PREDICT_NOISE,
'ace_step_cover_strength', 'ace_step_cover_strength',
audio_cover_strength_wrapper bound_wrapper
) )
return out return out
@ -1869,3 +1990,7 @@ class Kandinsky5Image(Kandinsky5):
def concat_cond(self, **kwargs): def concat_cond(self, **kwargs):
return None return None
class RT_DETR_v4(BaseModel):
    """Wrapper exposing the RT-DETR v4 transformer through the BaseModel interface."""

    # NOTE(review): model_type=ModelType.FLOW on a detection model looks like a
    # placeholder default — confirm the sampling type is actually consumed here.
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)