Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2025-04-17 20:51:12 +03:00 committed by GitHub
commit d9211bbf1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 85 additions and 31 deletions

View File

@ -193,7 +193,7 @@ class WanAttentionBlock(nn.Module):
e,
freqs,
context,
context_img_len=None,
context_img_len=257,
):
r"""
Args:
@ -251,7 +251,7 @@ class Head(nn.Module):
class MLPProj(torch.nn.Module):
def __init__(self, in_dim, out_dim, operation_settings={}):
def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}):
super().__init__()
self.proj = torch.nn.Sequential(
@ -259,7 +259,15 @@ class MLPProj(torch.nn.Module):
torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
if flf_pos_embed_token_number is not None:
self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
else:
self.emb_pos = None
def forward(self, image_embeds):
if self.emb_pos is not None:
image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
@ -285,6 +293,7 @@ class WanModel(torch.nn.Module):
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
device=None,
dtype=None,
@ -374,7 +383,7 @@ class WanModel(torch.nn.Module):
self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
if model_type == 'i2v':
self.img_emb = MLPProj(1280, dim, operation_settings=operation_settings)
self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings)
else:
self.img_emb = None

View File

@ -321,6 +321,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "i2v"
else:
dit_config["model_type"] = "t2v"
flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
if flf_weight is not None:
dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D

View File

@ -4,6 +4,7 @@ import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
import comfy.clip_vision
class WanImageToVideo:
@ -99,6 +100,72 @@ class WanFunControlToVideo:
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFirstLastFrameToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_start_image": ("CLIP_VISION_OUTPUT", ),
"clip_vision_end_image": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_start_image is not None:
clip_vision_output = clip_vision_start_image
if clip_vision_end_image is not None:
if clip_vision_output is not None:
states = torch.cat([clip_vision_output.penultimate_hidden_states, clip_vision_end_image.penultimate_hidden_states], dim=-2)
clip_vision_output = comfy.clip_vision.Output()
clip_vision_output.penultimate_hidden_states = states
else:
clip_vision_output = clip_vision_end_image
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanFunInpaintToVideo:
@classmethod
def INPUT_TYPES(s):
@ -122,38 +189,13 @@ class WanFunInpaintToVideo:
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_output=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
if end_image is not None:
end_image = comfy.utils.common_upscale(end_image[-length:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
flfv = WanFirstLastFrameToVideo()
return flfv.encode(positive, negative, vae, width, height, length, batch_size, start_image=start_image, end_image=end_image, clip_vision_start_image=clip_vision_output)
image = torch.ones((length, height, width, 3)) * 0.5
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
if start_image is not None:
image[:start_image.shape[0]] = start_image
mask[:, :, :start_image.shape[0] + 3] = 0.0
if end_image is not None:
image[-end_image.shape[0]:] = end_image
mask[:, :, -end_image.shape[0]:] = 0.0
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
"WanFunControlToVideo": WanFunControlToVideo,
"WanFunInpaintToVideo": WanFunInpaintToVideo,
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
}