def get_freqs(dim, max_period=10000.0):
    """Return ``dim`` geometrically decaying frequencies as a float32 tensor.

    Element ``i`` equals ``exp(-ln(max_period) * i / dim)``, i.e.
    ``max_period ** (-i / dim)``: the first entry is 1.0 and later entries
    shrink toward ``1 / max_period`` (which is never reached).

    Args:
        dim: number of frequencies to generate.
        max_period: base of the geometric progression.

    Returns:
        1-D ``torch.float32`` tensor of length ``dim``.
    """
    positions = torch.arange(start=0, end=dim, dtype=torch.float32)
    # Same evaluation order as the one-line form: scale first, then divide.
    exponents = -math.log(max_period) * positions / dim
    return torch.exp(exponents)
context, y, transformer_options=transformer_options, **kwargs) + ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9ce40482f..f9e546b1c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1654,16 +1654,8 @@ class Kandinsky5(BaseModel): def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) - - image = kwargs.get("concat_latent_image", None) device = kwargs["device"] - - if image is None: - image = torch.zeros_like(noise) - else: - image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") - image = self.process_latent_in(image) - image = utils.resize_to_batch_size(image, noise.shape[0]) + image = torch.zeros_like(noise) mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) if mask is None: @@ -1677,7 +1669,6 @@ class Kandinsky5(BaseModel): return torch.cat((image, mask), dim=1) - def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) attention_mask = kwargs.get("attention_mask", None) @@ -1687,4 +1678,8 @@ class Kandinsky5(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + time_dim_replace = kwargs.get("time_dim_replace", None) + if time_dim_replace is not None: + out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace)) + return out diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 257f07c42..fdcfcdde6 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO): '''Used by WAN Camera.''' time_dim_concat: NotRequired[torch.Tensor] '''Used by WAN Phantom Subject.''' + time_dim_replace: NotRequired[torch.Tensor] + '''Used by Kandinsky5 I2V.''' CondList = list[tuple[torch.Tensor, PooledDict]] Type = CondList diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py 
index dd9c73d3a..cb2d83595 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -7,6 +7,7 @@ import comfy.utils from typing_extensions import override from comfy_api.latest import ComfyExtension, io + class Kandinsky5ImageToVideo(io.ComfyNode): @classmethod def define_schema(cls): @@ -26,28 +27,82 @@ class Kandinsky5ImageToVideo(io.ComfyNode): outputs=[ io.Conditioning.Output(display_name="positive"), io.Conditioning.Output(display_name="negative"), - io.Latent.Output(display_name="latent"), + io.Latent.Output(display_name="latent", tooltip="Empty video latent"), + io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"), ], ) @classmethod def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + cond_latent_out = {} if start_image is not None: start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) encoded = vae.encode(start_image[:, :, :, :3]) - concat_latent_image = latent.clone() - concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded + cond_latent_out["samples"] = encoded - mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype) mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + 
def adaptive_mean_std_normalization(source, reference, *, mean_low=0.05, mean_high=0.1,
                                    std_low=0.1, std_high=0.25, eps=1e-8):
    """Pull each source frame's mean/std toward the reference statistics, within limits.

    For every (batch, frame) slice of ``source`` the mean and standard
    deviation are computed over dims 1, 3 and 4 (channels, height, width).
    The targets are the global mean and std of ``reference``, clamped to
    ``[source_mean - mean_low, source_mean + mean_high]`` and
    ``[source_std - std_low, source_std + std_high]`` respectively, so a frame
    can only drift a bounded amount away from its own statistics.

    Args:
        source: 5-D tensor indexed as (batch, channels, frames, height, width).
        reference: tensor whose overall mean and std are the targets.
        mean_low: how far below the source mean the target mean may fall.
        mean_high: how far above the source mean the target mean may rise.
        std_low: how far below the source std the target std may fall.
        std_high: how far above the source std the target std may rise.
        eps: added to the source std before dividing, to avoid division by zero.

    Returns:
        A new tensor shaped like ``source``. When the statistics are undefined
        (``reference`` has fewer than two elements, or a source frame slice
        does) an unmodified copy of ``source`` is returned instead of NaNs.
    """
    # The mean of an empty tensor and the (default, unbiased) std of fewer than
    # two samples are NaN; NaN targets would poison every normalized frame.
    per_frame_count = source.shape[1] * source.shape[3] * source.shape[4]
    if reference.numel() < 2 or per_frame_count < 2:
        return source.clone()

    stat_dims = (1, 3, 4)  # channels, height, width -> one statistic per (batch, frame)
    source_mean = source.mean(dim=stat_dims, keepdim=True)
    source_std = source.std(dim=stat_dims, keepdim=True)

    # Bound how far the target statistics may move away from the source's own.
    target_mean = torch.clamp(reference.mean(), source_mean - mean_low, source_mean + mean_high)
    target_std = torch.clamp(reference.std(), source_std - std_low, source_std + std_high)

    # Standardize each frame, then re-scale and re-center onto the targets.
    standardized = (source - source_mean) / (source_std + eps)
    return standardized * target_std + target_mean
class ReplaceVideoLatentFrames(io.ComfyNode):
    """Node that overwrites a run of frames in one video latent with another latent's frames.

    The frame axis is dim 2 of the ``"samples"`` tensor. The source frames are
    written into a clone of the destination samples starting at ``index``; the
    destination latent itself is not modified.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node: two latent inputs, a start index, one latent output."""
        return io.Schema(
            node_id="ReplaceVideoLatentFrames",
            category="latent/batch",
            inputs=[
                io.Latent.Input("destination"),
                io.Latent.Input("source"),
                io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="Frame position in the destination where the source frames are written. Negative values count from the end."),
            ],
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, source, destination, index) -> io.NodeOutput:
        """Return a latent whose samples are ``destination`` with ``source`` frames pasted in.

        Args:
            source: latent dict providing the replacement frames under "samples".
            destination: latent dict whose "samples" receive the frames.
            index: first destination frame to overwrite; negative values are
                interpreted relative to the end of the destination.

        Raises:
            RuntimeError: if ``index`` lies outside the destination or the
                source frames do not fit at that position.
        """
        source_samples = source.get("samples")
        if source_samples is None:
            # An upstream node may emit an empty latent dict when it has nothing
            # to contribute; there is nothing to replace, so pass through.
            return io.NodeOutput(destination)

        dest_frames = destination["samples"].shape[2]
        source_frames = source_samples.shape[2]

        # The schema allows negative indices. Resolve them explicitly: a raw
        # negative slice such as [-1:0] is empty and would fail with an opaque
        # shape-mismatch error on assignment.
        start = index + dest_frames if index < 0 else index
        if start < 0 or start > dest_frames:
            raise RuntimeError(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
        if start + source_frames > dest_frames:
            raise RuntimeError(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")

        # NOTE(review): the output inherits its non-"samples" keys from `source`
        # (unchanged behavior), although its samples have the destination's
        # shape - confirm whether `destination.copy()` was intended.
        result = source.copy()
        merged = destination["samples"].clone()
        merged[:, :, start:start + source_frames] = source_samples
        result["samples"] = merged
        return io.NodeOutput(result)