diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index 9b69c85a1..e44048447 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -1321,6 +1321,16 @@ class NaDiT(nn.Module):
             layers=["out"],
             modes=["in"],
         )
+
+        # Index into the sampler's sigma schedule at which CFGGuider drops cfg
+        # to 1.0; -1 selects the second-to-last entry of the schedule.
+        self.stop_cfg_index = -1
+
+    def set_cfg_stop_index(self, cfg):
+        self.stop_cfg_index = cfg
+
+    def get_cfg_stop_index(self):
+        return self.stop_cfg_index
 
     def forward(
         self,
@@ -1335,14 +1345,19 @@ class NaDiT(nn.Module):
         blocks_replace = patches_replace.get("dit", {})
         conditions = kwargs.get("condition")
 
-        pos_cond, neg_cond = context.chunk(2, dim=0)
-        pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
-        txt, txt_shape = flatten([pos_cond, neg_cond])
+        try:
+            neg_cond, pos_cond = context.chunk(2, dim=0)
+            pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
+            txt, txt_shape = flatten([pos_cond, neg_cond])
+        except ValueError:
+            # Expected when chunk() returns fewer than two tensors and the tuple
+            # unpack fails: treat the whole context as a single condition.
+            txt, txt_shape = flatten([context.squeeze(0)])
 
         vid, vid_shape = flatten(x)
         cond_latent, _ = flatten(conditions)
 
-        vid = torch.cat([cond_latent, vid], dim=-1)
+        vid = torch.cat([vid, cond_latent], dim=-1)
 
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
@@ -1404,4 +1419,10 @@ class NaDiT(nn.Module):
 
         vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
         vid = unflatten(vid, vid_shape)
-        return torch.stack(vid)
+        out = torch.stack(vid)
+        try:
+            pos, neg = out.chunk(2)
+            return torch.cat([neg, pos])
+        except ValueError:
+            # Only one chunk came back, so there is no pair to reorder.
+            return out
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 25ccaf39f..c159055dd 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -947,8 +947,24 @@ class CFGGuider:
 
     def __call__(self, *args, **kwargs):
         return self.predict_noise(*args, **kwargs)
+
+    def handle_dynamic_cfg(self, timestep, model_options):
+        """Call set_cfg(1.0) once the nearest sigma index hits the model's stop_cfg_index."""
+        if hasattr(self.model_patcher.model.diffusion_model, "stop_cfg_index"):
+            stop_index = self.model_patcher.model.diffusion_model.stop_cfg_index
+            transformer_options = model_options.get("transformer_options", {})
+            sigmas = transformer_options.get("sample_sigmas", None)
+            # Both conditions are required: a missing schedule cannot be searched,
+            # and cfg == 1.0 leaves nothing to switch off.
+            if sigmas is not None and self.cfg != 1.0:
+                dist = torch.abs(sigmas - timestep)
+                i = torch.argmin(dist).item()
+
+                if stop_index == i or (stop_index == -1 and i == len(sigmas) - 2):
+                    self.set_cfg(1.0)
 
     def predict_noise(self, x, timestep, model_options={}, seed=None):
+        self.handle_dynamic_cfg(timestep, model_options)
         return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
 
     def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py
index ae5d2c563..a42bf2b6a 100644
--- a/comfy_extras/nodes_model_advanced.py
+++ b/comfy_extras/nodes_model_advanced.py
@@ -5,6 +5,24 @@
 import nodes
 import torch
 import node_helpers
+
+class CFGCutoff:
+    """Node that forwards cfg_stop_index to the diffusion model's set_cfg_stop_index, if it has one."""
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL",), "cfg_stop_index": ("INT", {"default": -1, "min": -1, })}}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, cfg_stop_index):
+        diff_model = model.model.diffusion_model
+        if hasattr(diff_model, "set_cfg_stop_index"):
+            diff_model.set_cfg_stop_index(cfg_stop_index)
+
+        return (model,)
 
 class LCM(comfy.model_sampling.EPS):
     def calculate_denoised(self, sigma, model_output, model_input):
@@ -326,4 +344,5 @@ NODE_CLASS_MAPPINGS = {
     "ModelSamplingFlux": ModelSamplingFlux,
     "RescaleCFG": RescaleCFG,
     "ModelComputeDtype": ModelComputeDtype,
+    "CFGCutoff": CFGCutoff,
 }