Fix I2V, add necessary latent post process nodes

This commit is contained in:
kijai 2025-11-26 01:34:14 +02:00
parent 0817e591ad
commit 292c3576c2
5 changed files with 109 additions and 24 deletions

View File

@ -28,10 +28,7 @@ def get_shift_scale_gate(params):
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
def get_freqs(dim, max_period=10000.0):
return torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=dim, dtype=torch.float32)
/ dim)
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
class TimeEmbeddings(nn.Module):
@ -354,16 +351,22 @@ class Kandinsky5(nn.Module):
visual_embed = visual_embed.reshape(*visual_shape, -1)
return self.out_layer(visual_embed, time_embed)
def _forward(self, x, timestep, context, y, transformer_options={}, **kwargs):
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
bs, c, t_len, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if time_dim_replace is not None:
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
def forward(self, x, timestep, context, y, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, transformer_options=transformer_options, **kwargs)
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)

View File

@ -1654,16 +1654,8 @@ class Kandinsky5(BaseModel):
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0])
image = torch.zeros_like(noise)
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
@ -1677,7 +1669,6 @@ class Kandinsky5(BaseModel):
return torch.cat((image, mask), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
@ -1687,4 +1678,8 @@ class Kandinsky5(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
time_dim_replace = kwargs.get("time_dim_replace", None)
if time_dim_replace is not None:
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
return out

View File

@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO):
'''Used by WAN Camera.'''
time_dim_concat: NotRequired[torch.Tensor]
'''Used by WAN Phantom Subject.'''
time_dim_replace: NotRequired[torch.Tensor]
'''Used by Kandinsky5 I2V.'''
CondList = list[tuple[torch.Tensor, PooledDict]]
Type = CondList

View File

@ -7,6 +7,7 @@ import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class Kandinsky5ImageToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -26,28 +27,82 @@ class Kandinsky5ImageToVideo(io.ComfyNode):
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond_latent_out = {}
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encoded = vae.encode(start_image[:, :, :, :3])
concat_latent_image = latent.clone()
concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded
cond_latent_out["samples"] = encoded
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
def adaptive_mean_std_normalization(source, reference):
source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
#magic constants - limit changes in latents
clump_mean_low = 0.05
clump_mean_high = 0.1
clump_std_low = 0.1
clump_std_high = 0.25
reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
# normalization
normalized = (source - source_mean) / (source_std + 1e-8)
normalized = normalized * reference_std + reference_mean
return normalized
class NormalizeVideoLatentFrames(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="NormalizeVideoLatentFrames",
category="conditioning/video_models",
description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames.",
inputs=[
io.Latent.Input("latent"),
io.Int.Input("frames_to_normalize", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of initial frames to normalize, counted from the start"),
io.Int.Input("reference_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of frames after the normalized frames to use as reference"),
],
outputs=[
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, latent, frames_to_normalize, reference_frames) -> io.NodeOutput:
if latent["samples"].shape[2] <= 1:
return latent
s = latent.copy()
samples = latent["samples"].clone()
first_frames = samples[:, :, :frames_to_normalize]
reference_frames_data = samples[:, :, frames_to_normalize:frames_to_normalize+min(reference_frames, samples.shape[2]-frames_to_normalize)]
normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
samples[:, :, :frames_to_normalize] = normalized_first_frames
s["samples"] = samples
return io.NodeOutput(s)
class Kandinsky5Extension(ComfyExtension):
@ -55,6 +110,7 @@ class Kandinsky5Extension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Kandinsky5ImageToVideo,
NormalizeVideoLatentFrames
]
async def comfy_entrypoint() -> Kandinsky5Extension:

View File

@ -388,6 +388,34 @@ class LatentOperationSharpen(io.ComfyNode):
return luminance * sharpened
return io.NodeOutput(sharpen)
class ReplaceVideoLatentFrames(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ReplaceVideoLatentFrames",
category="latent/batch",
inputs=[
io.Latent.Input("destination"),
io.Latent.Input("source"),
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, source, destination, index) -> io.NodeOutput:
if index > destination["samples"].shape[2]:
raise RuntimeError(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {destination['samples'].shape[2]}.")
if index + source["samples"].shape[2] > destination["samples"].shape[2]:
raise RuntimeError(f"ReplaceVideoLatentFrames: Source latent frames {source['samples'].shape[2]} do not fit within destination latent frames {destination['samples'].shape[2]} at the specified index {index}.")
s = source.copy()
s_source = source["samples"]
s_destination = destination["samples"].clone()
s_destination[:, :, index:index + s_source.shape[2]] = s_source
s["samples"] = s_destination
return io.NodeOutput(s)
class LatentExtension(ComfyExtension):
@override
@ -405,6 +433,7 @@ class LatentExtension(ComfyExtension):
LatentApplyOperationCFG,
LatentOperationTonemapReinhard,
LatentOperationSharpen,
ReplaceVideoLatentFrames
]