mirror of https://github.com/comfyanonymous/ComfyUI.git
Fix I2V, add necessary latent post process nodes
commit 292c3576c2
parent 0817e591ad
@@ -28,10 +28,7 @@ def get_shift_scale_gate(params):
     return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
 
 def get_freqs(dim, max_period=10000.0):
-    return torch.exp(
-        -math.log(max_period)
-        * torch.arange(start=0, end=dim, dtype=torch.float32)
-        / dim)
+    return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
 
 
 class TimeEmbeddings(nn.Module):
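The `get_freqs` change above is purely cosmetic. A minimal standalone sketch (assuming only `torch` and `math`, with an illustrative `dim=128`) confirming the collapsed one-liner computes the same inverse-frequency table as the old multi-line form:

```python
import math
import torch

def get_freqs_old(dim, max_period=10000.0):
    return torch.exp(
        -math.log(max_period)
        * torch.arange(start=0, end=dim, dtype=torch.float32)
        / dim)

def get_freqs_new(dim, max_period=10000.0):
    return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)

# Identical inverse-frequency tables for the sinusoidal time embedding.
assert torch.equal(get_freqs_old(128), get_freqs_new(128))
```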
@@ -354,16 +351,22 @@ class Kandinsky5(nn.Module):
         visual_embed = visual_embed.reshape(*visual_shape, -1)
         return self.out_layer(visual_embed, time_embed)
 
-    def _forward(self, x, timestep, context, y, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
         bs, c, t_len, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+
+        if time_dim_replace is not None:
+            time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
+            x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
+
         freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
 
         return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
 
-    def forward(self, x, timestep, context, y, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, y, transformer_options=transformer_options, **kwargs)
+        ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
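A minimal sketch of what the new `time_dim_replace` path in `_forward` does: after padding, the clean conditioning latent overwrites the leading channels and latent frames of the noisy input `x`. All shapes below are illustrative assumptions, not values from the commit:

```python
import torch

# (batch, channels, latent_frames, height, width) noisy video latent
x = torch.randn(1, 16, 21, 60, 104)
# clean latent covering only the first latent frame
time_dim_replace = torch.randn(1, 16, 1, 60, 104)

# Same slice assignment as in the diff: the replacement is trimmed to its own
# channel and temporal extent, so a shorter conditioning tensor only touches
# the start of the sequence.
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
```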
@@ -1654,16 +1654,8 @@ class Kandinsky5(BaseModel):
 
     def concat_cond(self, **kwargs):
         noise = kwargs.get("noise", None)
 
-        image = kwargs.get("concat_latent_image", None)
-        device = kwargs["device"]
-
-        if image is None:
-            image = torch.zeros_like(noise)
-        else:
-            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
-            image = self.process_latent_in(image)
-            image = utils.resize_to_batch_size(image, noise.shape[0])
+        image = torch.zeros_like(noise)
 
         mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
         if mask is None:
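After this hunk, `concat_cond` always zeros the image half of the concat input; the start frames now reach the model through `time_dim_replace` instead, so only the mask channel still carries information. A sketch under assumed shapes:

```python
import torch

noise = torch.randn(1, 16, 21, 60, 104)   # (B, C, T, H, W) video latent
image = torch.zeros_like(noise)           # was: the upscaled concat_latent_image
mask = torch.ones(1, 1, 21, 60, 104)
mask[:, :, :1] = 0.0                      # 0.0 marks the conditioned frames

# what concat_cond returns, concatenated along the channel dim
concat = torch.cat((image, mask), dim=1)  # (1, 17, 21, 60, 104)
```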
@@ -1677,7 +1669,6 @@ class Kandinsky5(BaseModel):
 
         return torch.cat((image, mask), dim=1)
 
-
     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
         attention_mask = kwargs.get("attention_mask", None)
@@ -1687,4 +1678,8 @@ class Kandinsky5(BaseModel):
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
 
+        time_dim_replace = kwargs.get("time_dim_replace", None)
+        if time_dim_replace is not None:
+            out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
+
         return out
@@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO):
         '''Used by WAN Camera.'''
         time_dim_concat: NotRequired[torch.Tensor]
         '''Used by WAN Phantom Subject.'''
+        time_dim_replace: NotRequired[torch.Tensor]
+        '''Used by Kandinsky5 I2V.'''
 
     CondList = list[tuple[torch.Tensor, PooledDict]]
     Type = CondList
@@ -7,6 +7,7 @@ import comfy.utils
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
 
 
 class Kandinsky5ImageToVideo(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -26,28 +27,82 @@ class Kandinsky5ImageToVideo(io.ComfyNode):
             outputs=[
                 io.Conditioning.Output(display_name="positive"),
                 io.Conditioning.Output(display_name="negative"),
-                io.Latent.Output(display_name="latent"),
+                io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
+                io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
             ],
         )
 
     @classmethod
     def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        cond_latent_out = {}
         if start_image is not None:
             start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
             encoded = vae.encode(start_image[:, :, :, :3])
-            concat_latent_image = latent.clone()
-            concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded
+            cond_latent_out["samples"] = encoded
 
-            mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+            mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
             mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
 
-            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
-            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+            positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
+            negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
 
         out_latent = {}
         out_latent["samples"] = latent
-        return io.NodeOutput(positive, negative, out_latent)
+        return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
+
+
+def adaptive_mean_std_normalization(source, reference):
+    source_mean = source.mean(dim=(1, 3, 4), keepdim=True)  # mean over C, H, W
+    source_std = source.std(dim=(1, 3, 4), keepdim=True)  # std over C, H, W
+    # magic constants - limit changes in latents
+    clump_mean_low = 0.05
+    clump_mean_high = 0.1
+    clump_std_low = 0.1
+    clump_std_high = 0.25
+
+    reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
+    reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
+
+    # normalization
+    normalized = (source - source_mean) / (source_std + 1e-8)
+    normalized = normalized * reference_std + reference_mean
+
+    return normalized
+
+
+class NormalizeVideoLatentFrames(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="NormalizeVideoLatentFrames",
+            category="conditioning/video_models",
+            description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames.",
+            inputs=[
+                io.Latent.Input("latent"),
+                io.Int.Input("frames_to_normalize", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of initial frames to normalize, counted from the start"),
+                io.Int.Input("reference_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of frames after the normalized frames to use as reference"),
+            ],
+            outputs=[
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, latent, frames_to_normalize, reference_frames) -> io.NodeOutput:
+        if latent["samples"].shape[2] <= 1:
+            return latent
+        s = latent.copy()
+        samples = latent["samples"].clone()
+
+        first_frames = samples[:, :, :frames_to_normalize]
+        reference_frames_data = samples[:, :, frames_to_normalize:frames_to_normalize + min(reference_frames, samples.shape[2] - frames_to_normalize)]
+
+        normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
+
+        samples[:, :, :frames_to_normalize] = normalized_first_frames
+        s["samples"] = samples
+        return io.NodeOutput(s)
+
+
 class Kandinsky5Extension(ComfyExtension):
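The new `adaptive_mean_std_normalization` helper can be exercised on its own. The sketch below restates it outside the node (inlining the `clump_*` constants but keeping their values) and applies it with the node's defaults: normalize the first 4 latent frames against the next 5. Shapes are assumptions:

```python
import torch

def adaptive_mean_std_normalization(source, reference):
    # per-frame statistics over channels and spatial dims (C, H, W)
    source_mean = source.mean(dim=(1, 3, 4), keepdim=True)
    source_std = source.std(dim=(1, 3, 4), keepdim=True)
    # clamp how far the reference statistics may pull each frame
    reference_mean = torch.clamp(reference.mean(), source_mean - 0.05, source_mean + 0.1)
    reference_std = torch.clamp(reference.std(), source_std - 0.1, source_std + 0.25)
    normalized = (source - source_mean) / (source_std + 1e-8)
    return normalized * reference_std + reference_mean

samples = torch.randn(1, 16, 21, 60, 104)  # 21 latent frames
fixed = adaptive_mean_std_normalization(samples[:, :, :4], samples[:, :, 4:9])
samples[:, :, :4] = fixed  # first 4 frames now track the stats of frames 4..8
```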
@@ -55,6 +110,7 @@ class Kandinsky5Extension(ComfyExtension):
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
         return [
             Kandinsky5ImageToVideo,
+            NormalizeVideoLatentFrames
         ]
 
 async def comfy_entrypoint() -> Kandinsky5Extension:
@@ -388,6 +388,34 @@ class LatentOperationSharpen(io.ComfyNode):
             return luminance * sharpened
         return io.NodeOutput(sharpen)
 
+class ReplaceVideoLatentFrames(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ReplaceVideoLatentFrames",
+            category="latent/batch",
+            inputs=[
+                io.Latent.Input("destination"),
+                io.Latent.Input("source"),
+                io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
+            ],
+            outputs=[
+                io.Latent.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, source, destination, index) -> io.NodeOutput:
+        if index > destination["samples"].shape[2]:
+            raise RuntimeError(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {destination['samples'].shape[2]}.")
+        if index + source["samples"].shape[2] > destination["samples"].shape[2]:
+            raise RuntimeError(f"ReplaceVideoLatentFrames: Source latent frames {source['samples'].shape[2]} do not fit within destination latent frames {destination['samples'].shape[2]} at the specified index {index}.")
+        s = source.copy()
+        s_source = source["samples"]
+        s_destination = destination["samples"].clone()
+        s_destination[:, :, index:index + s_source.shape[2]] = s_source
+        s["samples"] = s_destination
+        return io.NodeOutput(s)
+
 class LatentExtension(ComfyExtension):
     @override
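Stripped of the Latent-dict plumbing, the new `ReplaceVideoLatentFrames` node is a bounds-checked temporal slice assignment. A sketch of the intended use, writing clean start frames (e.g. the `cond_latent` output of the I2V node above) over the noisy start of a sampled latent; shapes are assumptions:

```python
import torch

destination = torch.randn(1, 16, 21, 60, 104)  # sampled video latent (B, C, T, H, W)
source = torch.randn(1, 16, 1, 60, 104)        # clean start-frame latent
index = 0

# same fit condition the node enforces with a RuntimeError
assert index + source.shape[2] <= destination.shape[2]
result = destination.clone()
result[:, :, index:index + source.shape[2]] = source
```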
@@ -405,6 +433,7 @@ class LatentExtension(ComfyExtension):
             LatentApplyOperationCFG,
             LatentOperationTonemapReinhard,
             LatentOperationSharpen,
+            ReplaceVideoLatentFrames
         ]