mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-17 00:43:48 +08:00
Fix I2V, add necessary latent post process nodes
This commit is contained in:
parent
0817e591ad
commit
292c3576c2
@ -28,10 +28,7 @@ def get_shift_scale_gate(params):
|
|||||||
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
|
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
|
||||||
|
|
||||||
def get_freqs(dim, max_period=10000.0):
|
def get_freqs(dim, max_period=10000.0):
|
||||||
return torch.exp(
|
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
|
||||||
-math.log(max_period)
|
|
||||||
* torch.arange(start=0, end=dim, dtype=torch.float32)
|
|
||||||
/ dim)
|
|
||||||
|
|
||||||
|
|
||||||
class TimeEmbeddings(nn.Module):
|
class TimeEmbeddings(nn.Module):
|
||||||
@ -354,16 +351,22 @@ class Kandinsky5(nn.Module):
|
|||||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||||
return self.out_layer(visual_embed, time_embed)
|
return self.out_layer(visual_embed, time_embed)
|
||||||
|
|
||||||
def _forward(self, x, timestep, context, y, transformer_options={}, **kwargs):
|
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||||
bs, c, t_len, h, w = x.shape
|
bs, c, t_len, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
|
|
||||||
|
if time_dim_replace is not None:
|
||||||
|
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
|
||||||
|
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
|
||||||
|
|
||||||
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||||
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||||
|
|
||||||
return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
|
return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
self._forward,
|
self._forward,
|
||||||
self,
|
self,
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
).execute(x, timestep, context, y, transformer_options=transformer_options, **kwargs)
|
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
|
||||||
|
|||||||
@ -1654,16 +1654,8 @@ class Kandinsky5(BaseModel):
|
|||||||
|
|
||||||
def concat_cond(self, **kwargs):
|
def concat_cond(self, **kwargs):
|
||||||
noise = kwargs.get("noise", None)
|
noise = kwargs.get("noise", None)
|
||||||
|
|
||||||
image = kwargs.get("concat_latent_image", None)
|
|
||||||
device = kwargs["device"]
|
device = kwargs["device"]
|
||||||
|
image = torch.zeros_like(noise)
|
||||||
if image is None:
|
|
||||||
image = torch.zeros_like(noise)
|
|
||||||
else:
|
|
||||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
|
||||||
image = self.process_latent_in(image)
|
|
||||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
|
||||||
|
|
||||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||||
if mask is None:
|
if mask is None:
|
||||||
@ -1677,7 +1669,6 @@ class Kandinsky5(BaseModel):
|
|||||||
|
|
||||||
return torch.cat((image, mask), dim=1)
|
return torch.cat((image, mask), dim=1)
|
||||||
|
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
attention_mask = kwargs.get("attention_mask", None)
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
@ -1687,4 +1678,8 @@ class Kandinsky5(BaseModel):
|
|||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
|
time_dim_replace = kwargs.get("time_dim_replace", None)
|
||||||
|
if time_dim_replace is not None:
|
||||||
|
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|||||||
@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO):
|
|||||||
'''Used by WAN Camera.'''
|
'''Used by WAN Camera.'''
|
||||||
time_dim_concat: NotRequired[torch.Tensor]
|
time_dim_concat: NotRequired[torch.Tensor]
|
||||||
'''Used by WAN Phantom Subject.'''
|
'''Used by WAN Phantom Subject.'''
|
||||||
|
time_dim_replace: NotRequired[torch.Tensor]
|
||||||
|
'''Used by Kandinsky5 I2V.'''
|
||||||
|
|
||||||
CondList = list[tuple[torch.Tensor, PooledDict]]
|
CondList = list[tuple[torch.Tensor, PooledDict]]
|
||||||
Type = CondList
|
Type = CondList
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import comfy.utils
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class Kandinsky5ImageToVideo(io.ComfyNode):
|
class Kandinsky5ImageToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -26,28 +27,82 @@ class Kandinsky5ImageToVideo(io.ComfyNode):
|
|||||||
outputs=[
|
outputs=[
|
||||||
io.Conditioning.Output(display_name="positive"),
|
io.Conditioning.Output(display_name="positive"),
|
||||||
io.Conditioning.Output(display_name="negative"),
|
io.Conditioning.Output(display_name="negative"),
|
||||||
io.Latent.Output(display_name="latent"),
|
io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
|
||||||
|
io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
cond_latent_out = {}
|
||||||
if start_image is not None:
|
if start_image is not None:
|
||||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
encoded = vae.encode(start_image[:, :, :, :3])
|
encoded = vae.encode(start_image[:, :, :, :3])
|
||||||
concat_latent_image = latent.clone()
|
cond_latent_out["samples"] = encoded
|
||||||
concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded
|
|
||||||
|
|
||||||
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||||
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
|
||||||
|
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return io.NodeOutput(positive, negative, out_latent)
|
return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
|
||||||
|
|
||||||
|
|
||||||
|
def adaptive_mean_std_normalization(
    source,
    reference,
    clamp_mean_low=0.05,
    clamp_mean_high=0.1,
    clamp_std_low=0.1,
    clamp_std_high=0.25,
):
    """Shift and rescale ``source`` toward the statistics of ``reference``.

    The mean and standard deviation of ``source`` are computed over dims
    1, 3 and 4 (kept as size-1 dims), so every (dim 0, dim 2) slice is
    normalized independently. The target statistics are the scalar mean and
    standard deviation of the whole ``reference`` tensor, clamped into a
    window around the corresponding source statistic so that the result
    cannot drift far from the input.

    Args:
        source: Tensor with at least 5 dims; the caller in this file passes
            the leading frames of a video latent (sliced along dim 2).
        reference: Tensor whose global mean/std are used as the target.
        clamp_mean_low: How far below the source mean the target mean may go.
        clamp_mean_high: How far above the source mean the target mean may go.
        clamp_std_low: How far below the source std the target std may go.
        clamp_std_high: How far above the source std the target std may go.

    Returns:
        A new tensor with the same shape as ``source``. When ``reference``
        has fewer than two elements its statistics are undefined, so an
        unchanged copy of ``source`` is returned instead of NaNs.
    """
    # mean()/std() of an empty tensor (and std() of a single element) are NaN,
    # which would propagate into every normalized value.
    if reference.numel() < 2:
        return source.clone()

    reduce_dims = (1, 3, 4)  # C, H, W for a (B, C, T, H, W) latent
    source_mean = source.mean(dim=reduce_dims, keepdim=True)
    source_std = source.std(dim=reduce_dims, keepdim=True)

    # Limit how much the latents may change: the scalar reference statistics
    # broadcast against the per-slice windows built from the source.
    reference_mean = torch.clamp(
        reference.mean(),
        source_mean - clamp_mean_low,
        source_mean + clamp_mean_high,
    )
    reference_std = torch.clamp(
        reference.std(),
        source_std - clamp_std_low,
        source_std + clamp_std_high,
    )

    # Standardize, then re-scale to the clamped target statistics.
    # The epsilon avoids division by zero for constant slices.
    normalized = (source - source_mean) / (source_std + 1e-8)
    return normalized * reference_std + reference_mean
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeVideoLatentFrames(io.ComfyNode):
    """Node that re-normalizes the leading frames of a video latent.

    The first ``frames_to_normalize`` frames (dim 2 of ``latent["samples"]``)
    are passed through ``adaptive_mean_std_normalization`` using the frames
    that immediately follow them as the statistical reference.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and outputs."""
        return io.Schema(
            node_id="NormalizeVideoLatentFrames",
            category="conditioning/video_models",
            description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames.",
            inputs=[
                io.Latent.Input("latent"),
                io.Int.Input("frames_to_normalize", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of initial frames to normalize, counted from the start"),
                io.Int.Input("reference_frames", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of frames after the normalized frames to use as reference"),
            ],
            outputs=[
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, latent, frames_to_normalize, reference_frames) -> io.NodeOutput:
        """Return a latent whose leading frames are normalized.

        Args:
            latent: Latent dict; only the ``"samples"`` entry is modified.
            frames_to_normalize: Number of frames at the start to normalize.
            reference_frames: Maximum number of following frames to use as
                the reference; truncated to the frames actually available.

        Returns:
            ``io.NodeOutput`` wrapping a shallow copy of ``latent`` with a
            new ``"samples"`` tensor. The input is passed through unchanged
            when there is at most one frame, or when no frames remain after
            the normalized span to act as a reference.
        """
        total_frames = latent["samples"].shape[2]

        # Both early exits wrap the result in NodeOutput so every path
        # matches the declared return type. Without reference frames the
        # target statistics would be computed on an empty tensor (NaN).
        if total_frames <= 1 or frames_to_normalize >= total_frames:
            return io.NodeOutput(latent)

        s = latent.copy()
        samples = latent["samples"].clone()  # keep the caller's tensor intact

        reference_count = min(reference_frames, total_frames - frames_to_normalize)
        reference_end = frames_to_normalize + reference_count

        first_frames = samples[:, :, :frames_to_normalize]
        reference_frames_data = samples[:, :, frames_to_normalize:reference_end]

        samples[:, :, :frames_to_normalize] = adaptive_mean_std_normalization(first_frames, reference_frames_data)
        s["samples"] = samples
        return io.NodeOutput(s)
|
||||||
|
|
||||||
|
|
||||||
class Kandinsky5Extension(ComfyExtension):
|
class Kandinsky5Extension(ComfyExtension):
|
||||||
@ -55,6 +110,7 @@ class Kandinsky5Extension(ComfyExtension):
|
|||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
Kandinsky5ImageToVideo,
|
Kandinsky5ImageToVideo,
|
||||||
|
NormalizeVideoLatentFrames
|
||||||
]
|
]
|
||||||
|
|
||||||
async def comfy_entrypoint() -> Kandinsky5Extension:
|
async def comfy_entrypoint() -> Kandinsky5Extension:
|
||||||
|
|||||||
@ -388,6 +388,34 @@ class LatentOperationSharpen(io.ComfyNode):
|
|||||||
return luminance * sharpened
|
return luminance * sharpened
|
||||||
return io.NodeOutput(sharpen)
|
return io.NodeOutput(sharpen)
|
||||||
|
|
||||||
|
class ReplaceVideoLatentFrames(io.ComfyNode):
    """Node that overwrites a run of frames in one video latent with another.

    Frames are addressed along dim 2 of the ``"samples"`` tensors.
    """

    @classmethod
    def define_schema(cls):
        """Declare the node id, category, inputs and outputs."""
        return io.Schema(
            node_id="ReplaceVideoLatentFrames",
            category="latent/batch",
            inputs=[
                io.Latent.Input("destination"),
                io.Latent.Input("source"),
                io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
            ],
            outputs=[
                io.Latent.Output(),
            ],
        )

    @classmethod
    def execute(cls, source, destination, index) -> io.NodeOutput:
        """Write ``source`` frames into a copy of ``destination`` at ``index``.

        Args:
            source: Latent dict whose ``"samples"`` frames are written.
            destination: Latent dict whose ``"samples"`` are cloned and
                partially overwritten; the input tensor is not modified.
            index: Frame position in ``destination`` where the first source
                frame is placed. A negative value counts from the end of
                ``destination``.

        Returns:
            ``io.NodeOutput`` wrapping the resulting latent dict.

        Raises:
            RuntimeError: If ``index`` lies outside ``destination`` or the
                source frames do not fit at that position.
        """
        dest_frames = destination["samples"].shape[2]
        source_frames = source["samples"].shape[2]

        # The schema allows negative indices. Resolve them up front: a raw
        # negative slice such as [-2:0] is empty and cannot receive frames.
        start = index + dest_frames if index < 0 else index

        if start < 0 or start > dest_frames:
            raise RuntimeError(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {destination['samples'].shape[2]}.")
        if start + source_frames > dest_frames:
            raise RuntimeError(f"ReplaceVideoLatentFrames: Source latent frames {source['samples'].shape[2]} do not fit within destination latent frames {destination['samples'].shape[2]} at the specified index {index}.")

        # The output dict starts as a shallow copy of `source`, so every
        # entry other than "samples" comes from the source latent.
        # NOTE(review): confirm that inheriting source (not destination)
        # metadata is intended.
        s = source.copy()
        s_source = source["samples"]
        s_destination = destination["samples"].clone()
        s_destination[:, :, start:start + source_frames] = s_source
        s["samples"] = s_destination
        return io.NodeOutput(s)
|
||||||
|
|
||||||
class LatentExtension(ComfyExtension):
|
class LatentExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
@ -405,6 +433,7 @@ class LatentExtension(ComfyExtension):
|
|||||||
LatentApplyOperationCFG,
|
LatentApplyOperationCFG,
|
||||||
LatentOperationTonemapReinhard,
|
LatentOperationTonemapReinhard,
|
||||||
LatentOperationSharpen,
|
LatentOperationSharpen,
|
||||||
|
ReplaceVideoLatentFrames
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user