mirror of https://github.com/comfyanonymous/ComfyUI.git
commit d9f71da998 (parent ebd945ce3d)

    works
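In short: the NaDiT model gains a stop_cfg_index attribute with accessors, CFGGuider learns to drop the CFG scale to 1.0 once sampling reaches that step, and a new CFGCutoff node (registered in NODE_CLASS_MAPPINGS) exposes the index to workflows.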
@@ -1321,6 +1321,14 @@ class NaDiT(nn.Module):
             layers=["out"],
             modes=["in"],
         )

+        self.stop_cfg_index = -1
+
+    def set_cfg_stop_index(self, cfg):
+        self.stop_cfg_index = cfg
+
+    def get_cfg_stop_index(self):
+        return self.stop_cfg_index
+
     def forward(
         self,
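The -1 default is a sentinel: as the CFGGuider hunk below shows, it means "turn CFG off on the last sampling step" rather than at a fixed index.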
@@ -1335,14 +1343,17 @@ class NaDiT(nn.Module):
         blocks_replace = patches_replace.get("dit", {})
         conditions = kwargs.get("condition")

-        pos_cond, neg_cond = context.chunk(2, dim=0)
-        pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
-        txt, txt_shape = flatten([pos_cond, neg_cond])
+        try:
+            neg_cond, pos_cond = context.chunk(2, dim=0)
+            pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
+            txt, txt_shape = flatten([pos_cond, neg_cond])
+        except Exception:
+            txt, txt_shape = flatten([context.squeeze(0)])

         vid, vid_shape = flatten(x)
         cond_latent, _ = flatten(conditions)

-        vid = torch.cat([cond_latent, vid], dim=-1)
+        vid = torch.cat([vid, cond_latent], dim=-1)
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])

@@ -1404,4 +1415,9 @@ class NaDiT(nn.Module):

         vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
         vid = unflatten(vid, vid_shape)
-        return torch.stack(vid)
+        out = torch.stack(vid)
+        try:
+            pos, neg = out.chunk(2)
+            return torch.cat([neg, pos])
+        except Exception:
+            return out
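The input swap in the second hunk and the output swap here are two halves of the same fix: ComfyUI's sampling_function batches the negative condition ahead of the positive one, while this model works in [pos, neg] order internally. A minimal round-trip sketch with toy tensors (shapes are made up):

    import torch

    # `context` arrives batched as [negative, positive] (hypothetical shapes).
    context = torch.stack([torch.full((3, 4), -1.0), torch.full((3, 4), 1.0)])

    neg_cond, pos_cond = context.chunk(2, dim=0)           # split the CFG batch
    pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)

    # ... the model processes [pos, neg] internally ...
    out = torch.stack([pos_cond, neg_cond])                # model output order

    pos, neg = out.chunk(2)                                # undo the internal order
    restored = torch.cat([neg, pos])                       # back to [neg, pos]
    assert torch.equal(restored, context)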
@@ -947,8 +947,21 @@ class CFGGuider:

     def __call__(self, *args, **kwargs):
         return self.predict_noise(*args, **kwargs)

+    def handle_dynamic_cfg(self, timestep, model_options):
+        if hasattr(self.model_patcher.model.diffusion_model, "stop_cfg_index"):
+            stop_index = self.model_patcher.model.diffusion_model.stop_cfg_index
+            transformer_options = model_options.get("transformer_options", {})
+            sigmas = transformer_options.get("sample_sigmas", None)
+            if sigmas is not None and self.cfg != 1.0:
+                dist = torch.abs(sigmas - timestep)
+                i = torch.argmin(dist).item()
+
+                if stop_index == i or (stop_index == -1 and i == len(sigmas) - 2):
+                    self.set_cfg(1.0)
+
     def predict_noise(self, x, timestep, model_options={}, seed=None):
+        self.handle_dynamic_cfg(timestep, model_options)
         return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)

     def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
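handle_dynamic_cfg maps the incoming timestep (a sigma value) to a step index by picking the closest entry in the sampler's schedule, then compares that index against the cutoff. A self-contained sketch with a made-up schedule:

    import torch

    # Hypothetical 6-entry sigma schedule; the final 0.0 is the terminal value,
    # so the last *step* runs at index len(sigmas) - 2.
    sample_sigmas = torch.tensor([1.0, 0.7, 0.4, 0.2, 0.1, 0.0])
    timestep = torch.tensor(0.4)  # sigma the sampler is currently at

    i = torch.argmin(torch.abs(sample_sigmas - timestep)).item()
    print(i)  # 2 -> third step of the schedule

    stop_index = -1  # the node's default: cut CFG on the final step only
    if stop_index == i or (stop_index == -1 and i == len(sample_sigmas) - 2):
        print("CFG would be forced to 1.0 from here on")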
@@ -5,6 +5,22 @@ import nodes
 import torch
 import node_helpers

+class CFGCutoff:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"model": ("MODEL",), "cfg_stop_index": ("INT", {"default": -1, "min": -1})}}
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "patch"
+
+    CATEGORY = "advanced/model"
+
+    def patch(self, model, cfg_stop_index):
+        diff_model = model.model.diffusion_model
+        if hasattr(diff_model, "set_cfg_stop_index"):
+            diff_model.set_cfg_stop_index(cfg_stop_index)
+
+        return (model,)
+
 class LCM(comfy.model_sampling.EPS):
     def calculate_denoised(self, sigma, model_output, model_input):
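Because patch checks hasattr(diff_model, "set_cfg_stop_index"), the node is a no-op for models that don't implement the setter. A rough sketch of driving it directly, assuming `model` is a ModelPatcher wrapping a NaDiT checkpoint (hypothetical setup):

    # Assumes `model` was produced by a loader node and wraps a NaDiT instance.
    node = CFGCutoff()
    (patched,) = node.patch(model, cfg_stop_index=10)

    # From step 10 on, CFGGuider.handle_dynamic_cfg forces cfg=1.0, which lets
    # the sampler skip the extra unconditional pass for the remaining steps.
    assert patched.model.diffusion_model.get_cfg_stop_index() == 10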
@@ -326,4 +342,5 @@ NODE_CLASS_MAPPINGS = {
     "ModelSamplingFlux": ModelSamplingFlux,
     "RescaleCFG": RescaleCFG,
     "ModelComputeDtype": ModelComputeDtype,
+    "CFGCutoff": CFGCutoff,
 }