From ae4afd78e06fd44dd181fb42cd42032de1c70dea Mon Sep 17 00:00:00 2001 From: Subarasheese <35178075+Subarasheese@users.noreply.github.com> Date: Wed, 18 Feb 2026 01:42:42 -0300 Subject: [PATCH 1/6] Update nodes_ace.py --- comfy_extras/nodes_ace.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index 9cf84ab4d..de55080d4 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -50,14 +50,16 @@ class TextEncodeAceStepAudio15(io.ComfyNode): io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), io.Int.Input("top_k", default=0, min=0, max=100, advanced=True), io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("audio_cover_strength", default=1.0, min=0.0, max=1.0, step=0.01, advanced=True, tooltip="Controls how many denoising steps use LM code conditioning. 1.0 = all steps, 0.5 = first half only."), ], outputs=[io.Conditioning.Output()], ) @classmethod - def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput: + def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p, audio_cover_strength) -> io.NodeOutput: tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p) conditioning = clip.encode_from_tokens_scheduled(tokens) + conditioning = node_helpers.conditioning_set_values(conditioning, {"audio_cover_strength": audio_cover_strength}) return io.NodeOutput(conditioning) From 8e04804d5e7bf79e370cf75653c0960193e6c175 Mon Sep 17 00:00:00 2001 From: Subarasheese <35178075+Subarasheese@users.noreply.github.com> Date: Wed, 18 Feb 2026 01:43:25 -0300 Subject: [PATCH 2/6] Update model_base.py --- comfy/model_base.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9dcef8741..9e6d48429 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1588,6 +1588,35 @@ class ACEStep15(BaseModel): refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2) out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) + + audio_cover_strength = kwargs.get('audio_cover_strength', 1.0) + is_cover_mode = out.get('is_covers', comfy.conds.CONDConstant(None)).cond != False + if audio_cover_strength < 1.0 and is_cover_mode and self.current_patcher is not None: + if not self.current_patcher.get_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, 'ace_step_cover_strength'): + _strength = audio_cover_strength + def audio_cover_strength_wrapper(executor, x, timestep, model_options={}, seed=None): + sample_sigmas = model_options.get('transformer_options', {}).get('sample_sigmas', None) + if sample_sigmas is not None: + current_sigma = float(timestep.max()) + max_sigma = float(sample_sigmas[0]) + min_sigma = float(sample_sigmas[-1]) + sigma_range = max_sigma - min_sigma + if sigma_range > 0: + progress = 1.0 - (current_sigma - min_sigma) / sigma_range + if progress >= _strength: + conds = model_options.get('conds', None) + if conds is not None: + for cond_list in conds.values(): + for cond in cond_list: + if 'model_conds' in cond and 'is_covers' in cond['model_conds']: + cond['model_conds']['is_covers'] = comfy.conds.CONDConstant(False) + return executor(x, timestep, model_options, seed) + self.current_patcher.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.PREDICT_NOISE, + 'ace_step_cover_strength', + audio_cover_strength_wrapper + ) + return out class Omnigen2(BaseModel): From bd27c7201acc3419f0fab26b4afb1c4fce0690c4 Mon Sep 17 00:00:00 2001 From: Subarasheese <35178075+Subarasheese@users.noreply.github.com> Date: Thu, 9 Apr 2026 20:58:45 -0300 Subject: [PATCH 3/6] Update model_base.py --- comfy/model_base.py | 167 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 146 insertions(+), 21 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9e6d48429..1394252b7 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging import comfy.ldm.lightricks.av_model +import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -51,6 +52,7 @@ import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 +import comfy.ldm.rt_detr.rtdetr_v4 import comfy.model_management import comfy.patcher_extension @@ -76,6 +78,7 @@ class ModelType(Enum): FLUX = 8 IMG_TO_IMG = 9 FLOW_COSMOS = 10 + IMG_TO_IMG_FLOW = 11 def model_sampling(model_config, model_type): @@ -108,6 +111,8 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLOW_COSMOS: c = comfy.model_sampling.COSMOS_RFLOW s = comfy.model_sampling.ModelSamplingCosmosRFlow + elif model_type == ModelType.IMG_TO_IMG_FLOW: + c = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSampling(s, c): pass @@ -282,6 +287,12 @@ class BaseModel(torch.nn.Module): return data return None + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + """Override in subclasses to handle model-specific cond slicing for context windows. + Return a sliced cond object, or None to fall through to default handling. + Use comfy.context_windows.slice_cond() for common cases.""" + return None + def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -880,7 +891,7 @@ class Flux(BaseModel): return torch.cat((image, mask), dim=1) def encode_adm(self, **kwargs): - return kwargs["pooled_output"] + return kwargs.get("pooled_output", None) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -922,6 +933,26 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class LongCatImage(Flux): + def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + transformer_options = transformer_options.copy() + rope_opts = transformer_options.get("rope_options", {}) + rope_opts = dict(rope_opts) + pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0 + rope_opts.setdefault("shift_t", 1.0) + rope_opts.setdefault("shift_y", pe_len) + rope_opts.setdefault("shift_x", pe_len) + transformer_options["rope_options"] = rope_opts + return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) + + def encode_adm(self, **kwargs): + return None + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out.pop('guidance', None) + return out + class Flux2(Flux): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -971,6 +1002,10 @@ class LTXV(BaseModel): if keyframe_idxs is not None: out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + return out def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): @@ -988,10 +1023,14 @@ class LTXAV(BaseModel): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) attention_mask = kwargs.get("attention_mask", None) + device = kwargs["device"] + if attention_mask is not None: out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: + if hasattr(self.diffusion_model, "preprocess_text_embeds"): + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False)) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) @@ -1019,6 +1058,14 @@ class LTXAV(BaseModel): if latent_shapes is not None: out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + + ref_audio = kwargs.get("ref_audio", None) + if ref_audio is not None: + out['ref_audio'] = comfy.conds.CONDConstant(ref_audio) + return out def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): @@ -1229,6 +1276,11 @@ class Lumina2(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class ZImagePixelSpace(Lumina2): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) + self.memory_usage_factor_conds = ("ref_latents",) + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) @@ -1336,6 +1388,11 @@ class WAN21_Vace(WAN21): out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "vace_context": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21_Camera(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) @@ -1388,6 +1445,11 @@ class WAN21_HuMo(WAN21): return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_Animate(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) @@ -1405,6 +1467,13 @@ class WAN22_Animate(WAN21): out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "face_pixel_values": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1) + if cond_key == "pose_latents": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) @@ -1441,6 +1510,11 @@ class WAN22_S2V(WAN21): out['reference_motion'] = reference_motion.shape return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) @@ -1462,6 +1536,50 @@ class WAN22(WAN21): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class WAN21_FlowRVS(WAN21): + def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None): + model_config.unet_config["model_type"] = "t2v" + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) + self.image_to_video = image_to_video + +class WAN21_SCAIL(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel) + self.memory_usage_factor_conds = ("reference_latent", "pose_latents") + self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + ref_latent = self.process_latent_in(reference_latents[-1]) + ref_mask = torch.ones_like(ref_latent[:, :4]) + ref_latent = torch.cat([ref_latent, ref_mask], dim=1) + out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + pose_latents = self.process_latent_in(pose_latents) + pose_mask = torch.ones_like(pose_latents[:, :4]) + pose_latents = torch.cat([pose_latents, pose_mask], dim=1) + out['pose_latents'] = comfy.conds.CONDRegular(pose_latents) + + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] + + return out + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) @@ -1547,6 +1665,24 @@ class ACEStep(BaseModel): out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) return out +def ace_audio_cover_strength_wrapper(strength, executor, x, timestep, model_options={}, seed=None): + sample_sigmas = model_options.get('transformer_options', {}).get('sample_sigmas', None) + if sample_sigmas is not None: + current_sigma = float(timestep.max()) + max_sigma = float(sample_sigmas[0]) + min_sigma = float(sample_sigmas[-1]) + sigma_range = max_sigma - min_sigma + if sigma_range > 0: + progress = 1.0 - (current_sigma - min_sigma) / sigma_range + if progress >= strength: + conds = model_options.get('conds', None) + if conds is not None: + for cond_list in conds.values(): + for cond in cond_list: + if 'model_conds' in cond and 'is_covers' in cond['model_conds']: + cond['model_conds']['is_covers'] = comfy.conds.CONDConstant(False) + return executor(x, timestep, model_options, seed) + class ACEStep15(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel) @@ -1589,32 +1725,17 @@ class ACEStep15(BaseModel): out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) - audio_cover_strength = kwargs.get('audio_cover_strength', 1.0) + audio_cover_strength = max(0.0, min(1.0, float(kwargs.get('audio_cover_strength', 1.0)))) is_cover_mode = out.get('is_covers', comfy.conds.CONDConstant(None)).cond != False + if audio_cover_strength < 1.0 and is_cover_mode and self.current_patcher is not None: if not self.current_patcher.get_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, 'ace_step_cover_strength'): - _strength = audio_cover_strength - def audio_cover_strength_wrapper(executor, x, timestep, model_options={}, seed=None): - sample_sigmas = model_options.get('transformer_options', {}).get('sample_sigmas', None) - if sample_sigmas is not None: - current_sigma = float(timestep.max()) - max_sigma = float(sample_sigmas[0]) - min_sigma = float(sample_sigmas[-1]) - sigma_range = max_sigma - min_sigma - if sigma_range > 0: - progress = 1.0 - (current_sigma - min_sigma) / sigma_range - if progress >= _strength: - conds = model_options.get('conds', None) - if conds is not None: - for cond_list in conds.values(): - for cond in cond_list: - if 'model_conds' in cond and 'is_covers' in cond['model_conds']: - cond['model_conds']['is_covers'] = comfy.conds.CONDConstant(False) - return executor(x, timestep, model_options, seed) + import functools + bound_wrapper = functools.partial(ace_audio_cover_strength_wrapper, audio_cover_strength) self.current_patcher.add_wrapper_with_key( comfy.patcher_extension.WrappersMP.PREDICT_NOISE, 'ace_step_cover_strength', - audio_cover_strength_wrapper + bound_wrapper ) return out @@ -1869,3 +1990,7 @@ class Kandinsky5Image(Kandinsky5): def concat_cond(self, **kwargs): return None + +class RT_DETR_v4(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4) From d1aeda9255226e5d81be5ff353c7d04f4da5896b Mon Sep 17 00:00:00 2001 From: Subarasheese <35178075+Subarasheese@users.noreply.github.com> Date: Thu, 9 Apr 2026 20:59:28 -0300 Subject: [PATCH 4/6] Update nodes_ace.py --- comfy_extras/nodes_ace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index de55080d4..ad4455672 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -82,7 +82,7 @@ class EmptyAceStepLatentAudio(io.ComfyNode): @classmethod def execute(cls, seconds, batch_size) -> io.NodeOutput: length = int(seconds * 44100 / 512 / 8) - latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device()) + latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return io.NodeOutput({"samples": latent, "type": "audio"}) @@ -105,7 +105,7 @@ class EmptyAceStep15LatentAudio(io.ComfyNode): @classmethod def execute(cls, seconds, batch_size) -> io.NodeOutput: length = round((seconds * 48000 / 1920)) - latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return io.NodeOutput({"samples": latent, "type": "audio"}) class ReferenceAudio(io.ComfyNode): From ac373bef57082485743ded26c25a5d84f076d708 Mon Sep 17 00:00:00 2001 From: Subarasheese <35178075+Subarasheese@users.noreply.github.com> Date: Thu, 9 Apr 2026 21:13:18 -0300 Subject: [PATCH 5/6] Update nodes_ace.py --- comfy_extras/nodes_ace.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index ad4455672..de55080d4 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -82,7 +82,7 @@ class EmptyAceStepLatentAudio(io.ComfyNode): @classmethod def execute(cls, seconds, batch_size) -> io.NodeOutput: length = int(seconds * 44100 / 512 / 8) - latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) + latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device()) return io.NodeOutput({"samples": latent, "type": "audio"}) @@ -105,7 +105,7 @@ class EmptyAceStep15LatentAudio(io.ComfyNode): @classmethod def execute(cls, seconds, batch_size) -> io.NodeOutput: length = round((seconds * 48000 / 1920)) - latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) return io.NodeOutput({"samples": latent, "type": "audio"}) class ReferenceAudio(io.ComfyNode): From debeab6b531ecc69d8f07fa3bc9a27ace88e55b9 Mon Sep 17 00:00:00 2001 From: Subarasheese <35178075+Subarasheese@users.noreply.github.com> Date: Thu, 9 Apr 2026 21:13:47 -0300 Subject: [PATCH 6/6] Update model_base.py --- comfy/model_base.py | 167 ++++++-------------------------------------- 1 file changed, 21 insertions(+), 146 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 1394252b7..9e6d48429 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -21,7 +21,6 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging import comfy.ldm.lightricks.av_model -import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -52,7 +51,6 @@ import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 -import comfy.ldm.rt_detr.rtdetr_v4 import comfy.model_management import comfy.patcher_extension @@ -78,7 +76,6 @@ class ModelType(Enum): FLUX = 8 IMG_TO_IMG = 9 FLOW_COSMOS = 10 - IMG_TO_IMG_FLOW = 11 def model_sampling(model_config, model_type): @@ -111,8 +108,6 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLOW_COSMOS: c = comfy.model_sampling.COSMOS_RFLOW s = comfy.model_sampling.ModelSamplingCosmosRFlow - elif model_type == ModelType.IMG_TO_IMG_FLOW: - c = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSampling(s, c): pass @@ -287,12 +282,6 @@ class BaseModel(torch.nn.Module): return data return None - def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - """Override in subclasses to handle model-specific cond slicing for context windows. - Return a sliced cond object, or None to fall through to default handling. - Use comfy.context_windows.slice_cond() for common cases.""" - return None - def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -891,7 +880,7 @@ class Flux(BaseModel): return torch.cat((image, mask), dim=1) def encode_adm(self, **kwargs): - return kwargs.get("pooled_output", None) + return kwargs["pooled_output"] def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -933,26 +922,6 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out -class LongCatImage(Flux): - def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): - transformer_options = transformer_options.copy() - rope_opts = transformer_options.get("rope_options", {}) - rope_opts = dict(rope_opts) - pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0 - rope_opts.setdefault("shift_t", 1.0) - rope_opts.setdefault("shift_y", pe_len) - rope_opts.setdefault("shift_x", pe_len) - transformer_options["rope_options"] = rope_opts - return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) - - def encode_adm(self, **kwargs): - return None - - def extra_conds(self, **kwargs): - out = super().extra_conds(**kwargs) - out.pop('guidance', None) - return out - class Flux2(Flux): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1002,10 +971,6 @@ class LTXV(BaseModel): if keyframe_idxs is not None: out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) - guide_attention_entries = kwargs.get("guide_attention_entries", None) - if guide_attention_entries is not None: - out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) - return out def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): @@ -1023,14 +988,10 @@ class LTXAV(BaseModel): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) attention_mask = kwargs.get("attention_mask", None) - device = kwargs["device"] - if attention_mask is not None: out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: - if hasattr(self.diffusion_model, "preprocess_text_embeds"): - cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False)) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) @@ -1058,14 +1019,6 @@ class LTXAV(BaseModel): if latent_shapes is not None: out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) - guide_attention_entries = kwargs.get("guide_attention_entries", None) - if guide_attention_entries is not None: - out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) - - ref_audio = kwargs.get("ref_audio", None) - if ref_audio is not None: - out['ref_audio'] = comfy.conds.CONDConstant(ref_audio) - return out def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): @@ -1276,11 +1229,6 @@ class Lumina2(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out -class ZImagePixelSpace(Lumina2): - def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) - self.memory_usage_factor_conds = ("ref_latents",) - class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) @@ -1388,11 +1336,6 @@ class WAN21_Vace(WAN21): out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out - def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - if cond_key == "vace_context": - return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list) - return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) - class WAN21_Camera(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) @@ -1445,11 +1388,6 @@ class WAN21_HuMo(WAN21): return out - def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - if cond_key == "audio_embed": - return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) - return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) - class WAN22_Animate(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) @@ -1467,13 +1405,6 @@ class WAN22_Animate(WAN21): out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) return out - def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - if cond_key == "face_pixel_values": - return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1) - if cond_key == "pose_latents": - return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1) - return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) - class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) @@ -1510,11 +1441,6 @@ class WAN22_S2V(WAN21): out['reference_motion'] = reference_motion.shape return out - def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): - if cond_key == "audio_embed": - return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) - return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) - class WAN22(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) @@ -1536,50 +1462,6 @@ class WAN22(WAN21): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image -class WAN21_FlowRVS(WAN21): - def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None): - model_config.unet_config["model_type"] = "t2v" - super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) - self.image_to_video = image_to_video - -class WAN21_SCAIL(WAN21): - def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): - super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel) - self.memory_usage_factor_conds = ("reference_latent", "pose_latents") - self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} - self.image_to_video = image_to_video - - def extra_conds(self, **kwargs): - out = super().extra_conds(**kwargs) - - reference_latents = kwargs.get("reference_latents", None) - if reference_latents is not None: - ref_latent = self.process_latent_in(reference_latents[-1]) - ref_mask = torch.ones_like(ref_latent[:, :4]) - ref_latent = torch.cat([ref_latent, ref_mask], dim=1) - out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) - - pose_latents = kwargs.get("pose_video_latent", None) - if pose_latents is not None: - pose_latents = self.process_latent_in(pose_latents) - pose_mask = torch.ones_like(pose_latents[:, :4]) - pose_latents = torch.cat([pose_latents, pose_mask], dim=1) - out['pose_latents'] = comfy.conds.CONDRegular(pose_latents) - - return out - - def extra_conds_shapes(self, **kwargs): - out = {} - ref_latents = kwargs.get("reference_latents", None) - if ref_latents is not None: - out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) - - pose_latents = kwargs.get("pose_video_latent", None) - if pose_latents is not None: - out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] - - return out - class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) @@ -1665,24 +1547,6 @@ class ACEStep(BaseModel): out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) return out -def ace_audio_cover_strength_wrapper(strength, executor, x, timestep, model_options={}, seed=None): - sample_sigmas = model_options.get('transformer_options', {}).get('sample_sigmas', None) - if sample_sigmas is not None: - current_sigma = float(timestep.max()) - max_sigma = float(sample_sigmas[0]) - min_sigma = float(sample_sigmas[-1]) - sigma_range = max_sigma - min_sigma - if sigma_range > 0: - progress = 1.0 - (current_sigma - min_sigma) / sigma_range - if progress >= strength: - conds = model_options.get('conds', None) - if conds is not None: - for cond_list in conds.values(): - for cond in cond_list: - if 'model_conds' in cond and 'is_covers' in cond['model_conds']: - cond['model_conds']['is_covers'] = comfy.conds.CONDConstant(False) - return executor(x, timestep, model_options, seed) - class ACEStep15(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel) @@ -1725,17 +1589,32 @@ class ACEStep15(BaseModel): out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) - audio_cover_strength = max(0.0, min(1.0, float(kwargs.get('audio_cover_strength', 1.0)))) + audio_cover_strength = kwargs.get('audio_cover_strength', 1.0) is_cover_mode = out.get('is_covers', comfy.conds.CONDConstant(None)).cond != False - if audio_cover_strength < 1.0 and is_cover_mode and self.current_patcher is not None: if not self.current_patcher.get_wrappers(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, 'ace_step_cover_strength'): - import functools - bound_wrapper = functools.partial(ace_audio_cover_strength_wrapper, audio_cover_strength) + _strength = audio_cover_strength + def audio_cover_strength_wrapper(executor, x, timestep, model_options={}, seed=None): + sample_sigmas = model_options.get('transformer_options', {}).get('sample_sigmas', None) + if sample_sigmas is not None: + current_sigma = float(timestep.max()) + max_sigma = float(sample_sigmas[0]) + min_sigma = float(sample_sigmas[-1]) + sigma_range = max_sigma - min_sigma + if sigma_range > 0: + progress = 1.0 - (current_sigma - min_sigma) / sigma_range + if progress >= _strength: + conds = model_options.get('conds', None) + if conds is not None: + for cond_list in conds.values(): + for cond in cond_list: + if 'model_conds' in cond and 'is_covers' in cond['model_conds']: + cond['model_conds']['is_covers'] = comfy.conds.CONDConstant(False) + return executor(x, timestep, model_options, seed) self.current_patcher.add_wrapper_with_key( comfy.patcher_extension.WrappersMP.PREDICT_NOISE, 'ace_step_cover_strength', - bound_wrapper + audio_cover_strength_wrapper ) return out @@ -1990,7 +1869,3 @@ class Kandinsky5Image(Kandinsky5): def concat_cond(self, **kwargs): return None - -class RT_DETR_v4(BaseModel): - def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)