diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py
index ca8867eaf..145e6e69a 100644
--- a/comfy/ldm/cascade/stage_a.py
+++ b/comfy/ldm/cascade/stage_a.py
@@ -19,6 +19,10 @@
 import torch
 from torch import nn
 from torch.autograd import Function
+import comfy.ops
+
+ops = comfy.ops.disable_weight_init
+

 class vector_quantize(Function):
     @staticmethod
@@ -121,15 +125,15 @@ class ResBlock(nn.Module):
         self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.depthwise = nn.Sequential(
             nn.ReplicationPad2d(1),
-            nn.Conv2d(c, c, kernel_size=3, groups=c)
+            ops.Conv2d(c, c, kernel_size=3, groups=c)
         )

         # channelwise
         self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
         self.channelwise = nn.Sequential(
-            nn.Linear(c, c_hidden),
+            ops.Linear(c, c_hidden),
             nn.GELU(),
-            nn.Linear(c_hidden, c),
+            ops.Linear(c_hidden, c),
         )
         self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

@@ -171,16 +175,16 @@ class StageA(nn.Module):
         # Encoder blocks
         self.in_block = nn.Sequential(
             nn.PixelUnshuffle(2),
-            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
+            ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
         )
         down_blocks = []
         for i in range(levels):
             if i > 0:
-                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
+                down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
             block = ResBlock(c_levels[i], c_levels[i] * 4)
             down_blocks.append(block)
         down_blocks.append(nn.Sequential(
-            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
+            ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
             nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
         ))
         self.down_blocks = nn.Sequential(*down_blocks)
@@ -191,7 +195,7 @@
         # Decoder blocks
         up_blocks = [nn.Sequential(
-            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
+            ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
         )]
         for i in range(levels):
             for j in range(bottleneck_blocks if i == 0 else 1):
@@ -199,11 +203,11 @@
                 up_blocks.append(block)
             if i < levels - 1:
                 up_blocks.append(
-                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
+                    ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                        padding=1))
         self.up_blocks = nn.Sequential(*up_blocks)
         self.out_block = nn.Sequential(
-            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
+            ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
             nn.PixelShuffle(2),
         )
@@ -232,17 +236,17 @@ class Discriminator(nn.Module):
         super().__init__()
         d = max(depth - 3, 3)
         layers = [
-            nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
+            nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
             nn.LeakyReLU(0.2),
         ]
         for i in range(depth - 1):
             c_in = c_hidden // (2 ** max((d - i), 0))
             c_out = c_hidden // (2 ** max((d - 1 - i), 0))
-            layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
+            layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
             layers.append(nn.InstanceNorm2d(c_out))
             layers.append(nn.LeakyReLU(0.2))
         self.encoder = nn.Sequential(*layers)
-        self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
+        self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
         self.logits = nn.Sigmoid()

     def forward(self, x, cond=None):
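
Note: comfy.ops.disable_weight_init supplies drop-in replacements for the torch.nn layers swapped out above; using the ops.* versions skips the default weight initialization, since the parameters are overwritten by the checkpoint load anyway. A minimal sketch of the idea (ComfyUI's actual implementation also covers more layer types and weight casting):

    import torch.nn as nn

    class disable_weight_init:
        class Conv2d(nn.Conv2d):
            def reset_parameters(self):
                return None  # skip the default init; weights come from the state dict

        class Linear(nn.Linear):
            def reset_parameters(self):
                return None

    ops = disable_weight_init
    conv = ops.Conv2d(16, 16, kernel_size=3)  # constructed without running the default init
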
diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py
index 0cb7c49fc..b467a70a8 100644
--- a/comfy/ldm/cascade/stage_c_coder.py
+++ b/comfy/ldm/cascade/stage_c_coder.py
@@ -19,6 +19,9 @@
 import torch
 import torchvision
 from torch import nn
+import comfy.ops
+
+ops = comfy.ops.disable_weight_init

 # EfficientNet
 class EfficientNetEncoder(nn.Module):
@@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
         super().__init__()
         self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
         self.mapper = nn.Sequential(
-            nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
+            ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
             nn.BatchNorm2d(c_latent, affine=False),  # then normalize them to have mean 0 and std 1
         )
         self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
@@ -34,7 +37,7 @@
     def forward(self, x):
         x = x * 0.5 + 0.5
-        x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
+        x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
         o = self.mapper(self.backbone(x))
         return o
@@ -44,39 +47,39 @@ class Previewer(nn.Module):
     def __init__(self, c_in=16, c_hidden=512, c_out=3):
         super().__init__()
         self.blocks = nn.Sequential(
-            nn.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
+            ops.Conv2d(c_in, c_hidden, kernel_size=1),  # 16 channels to 512 channels
             nn.GELU(),
             nn.BatchNorm2d(c_hidden),

-            nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
+            ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
             nn.GELU(),
             nn.BatchNorm2d(c_hidden),

-            nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
+            ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2),  # 16 -> 32
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 2),

-            nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
+            ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 2),

-            nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
+            ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2),  # 32 -> 64
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 4),

-            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+            ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 4),

-            nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2),  # 64 -> 128
+            ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2),  # 64 -> 128
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 4),

-            nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+            ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
             nn.GELU(),
             nn.BatchNorm2d(c_hidden // 4),

-            nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
+            ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
         )

     def forward(self, x):
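
Note: the EfficientNetEncoder change casts the ImageNet mean/std parameters onto the input's device and dtype before the broadcasted normalization, so the encoder keeps working when the input arrives in a different dtype (e.g. fp16) or on a different device than the parameters. A standalone sketch of the same normalization, with made-up tensors:

    import torch

    x = torch.rand(1, 3, 224, 224, dtype=torch.float16)  # input already mapped to [0, 1]
    mean = torch.tensor([0.485, 0.456, 0.406])            # fp32 parameters
    std = torch.tensor([0.229, 0.224, 0.225])
    x = (x - mean.view([3, 1, 1]).to(device=x.device, dtype=x.dtype)) / std.view([3, 1, 1]).to(device=x.device, dtype=x.dtype)
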
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 59a62e0df..76af967e6 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -105,7 +105,9 @@ class Modulation(nn.Module):
         self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)

     def forward(self, vec: Tensor) -> tuple:
-        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+        if vec.ndim == 2:
+            vec = vec[:, None, :]
+        out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)

         return (
             ModulationOut(*out[:3]),
@@ -113,6 +115,20 @@ class Modulation(nn.Module):
         )


+def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
+    if modulation_dims is None:
+        if m_add is not None:
+            return tensor * m_mult + m_add
+        else:
+            return tensor * m_mult
+    else:
+        for d in modulation_dims:
+            tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
+            if m_add is not None:
+                tensor[:, d[0]:d[1]] += m_add[:, d[2]]
+        return tensor
+
+
 class DoubleStreamBlock(nn.Module):
     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
         super().__init__()
@@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
         )
         self.flipped_img_txt = flipped_img_txt

-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)

         # prepare image for attention
         img_modulated = self.img_norm1(img)
-        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
-        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -228,9 +244,9 @@ class SingleStreamBlock(nn.Module):
         self.mlp_act = nn.GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

-    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
+    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
         mod, _ = self.modulation(vec)
-        qkv, mlp = torch.split(self.linear1((1 + mod.scale) * self.pre_norm(x) + mod.shift), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

         q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
@@ -239,7 +255,7 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x += apply_mod(output, mod.gate, None, modulation_dims)
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
@@ -252,8 +268,11 @@ class LastLayer(nn.Module):
         self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
         self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))

-    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
-        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
-        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+    def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
+        if vec.ndim == 2:
+            vec = vec[:, None, :]
+
+        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
+        x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
         x = self.linear(x)
         return x
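
Note: apply_mod generalizes the usual x * scale + shift modulation. With modulation_dims=None it reduces to the old elementwise form; otherwise each (start, end, row) entry applies row `row` of the modulation tensor to the token slice [start:end], which is what lets different token ranges be modulated by different vectors. A rough standalone illustration with made-up shapes (using an explicit row:row+1 slice so the broadcast is unambiguous):

    import torch

    B, T, C = 1, 6, 8
    tokens = torch.randn(B, T, C)
    scale = torch.randn(B, 2, C)    # two modulation vectors per batch element
    shift = torch.randn(B, 2, C)
    modulation_dims = [(0, 2, 0), (2, None, 1)]   # first 2 tokens use row 0, the rest use row 1

    out = tokens.clone()
    for start, end, row in modulation_dims:
        out[:, start:end] = out[:, start:end] * (1 + scale[:, row:row+1]) + shift[:, row:row+1]
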
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index f3f445843..72af3d5bb 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -227,6 +227,7 @@ class HunyuanVideo(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
+        guiding_frame_index=None,
         control=None,
         transformer_options={},
     ) -> Tensor:
@@ -237,7 +238,17 @@
         img = self.img_in(img)
         vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))

-        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+        if guiding_frame_index is not None:
+            token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
+            vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+            vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+            frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
+            modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
+            modulation_dims_txt = [(0, None, 1)]
+        else:
+            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+            modulation_dims = None
+            modulation_dims_txt = None

         if self.params.guidance_embed:
             if guidance is not None:
@@ -264,14 +275,14 @@
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
                     return out

-                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
                 txt = out["txt"]
                 img = out["img"]
             else:
-                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
+                img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)

             if control is not None: # Controlnet
                 control_i = control.get("input")
@@ -286,13 +297,13 @@
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
                     return out

-                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
                 img = out["img"]
             else:
-                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)

             if control is not None: # Controlnet
                 control_o = control.get("output")
@@ -303,7 +314,7 @@
         img = img[:, : img_len]

-        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+        img = self.final_layer(img, vec, modulation_dims=modulation_dims)  # (N, T, patch_size ** 2 * out_channels)

         shape = initial_shape[-3:]
         for i in range(len(shape)):
@@ -313,7 +324,7 @@
         img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
         return img

-    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -325,5 +336,5 @@
         img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
         img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
         return out
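
Note: when guiding_frame_index is set, two conditioning vectors are built: one derived from the guiding-frame index and one from the actual timestep, stacked along a new dimension of vec. modulation_dims then routes row 0 to the tokens of the first latent frame (the guiding frame) and row 1 to every other image token, while text tokens always use row 1. A worked example of the bookkeeping, assuming the model's 1x2x2 patch size and an illustrative latent shape:

    initial_shape = (1, 16, 14, 60, 104)   # (B, C, T, H, W) latent, assumed for illustration
    patch_size = (1, 2, 2)

    frame_tokens = (initial_shape[-1] // patch_size[-1]) * (initial_shape[-2] // patch_size[-2])  # 52 * 30 = 1560
    modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]   # image tokens
    modulation_dims_txt = [(0, None, 1)]                                # text tokens
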
pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap}) txt = out["txt"] img = out["img"] else: - img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask) + img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt) if control is not None: # Controlnet control_i = control.get("input") @@ -286,13 +297,13 @@ class HunyuanVideo(nn.Module): if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"]) + out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"]) return out - out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap}) + out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims) if control is not None: # Controlnet control_o = control.get("output") @@ -303,7 +314,7 @@ class HunyuanVideo(nn.Module): img = img[:, : img_len] - img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) shape = initial_shape[-3:] for i in range(len(shape)): @@ -313,7 +324,7 @@ class HunyuanVideo(nn.Module): img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) return img - def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, control=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs): bs, c, t, h, w = x.shape patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) @@ -325,5 +336,5 @@ class HunyuanVideo(nn.Module): img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options) return out diff --git a/comfy/model_base.py b/comfy/model_base.py index 2fa1ee911..bf4ebefa1 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -898,20 +898,31 @@ class HunyuanVideo(BaseModel): guidance = kwargs.get("guidance", 6.0) if guidance is not None: out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + + guiding_frame_index = kwargs.get("guiding_frame_index", None) + if guiding_frame_index is not None: + out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index])) + return out + def scale_latent_inpaint(self, latent_image, **kwargs): + return latent_image class 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index e7b4d5f0d..2309c1f43 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -582,7 +582,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
             loaded_memory = loaded_model.model_loaded_memory()
             current_free_mem = get_free_memory(torch_dev) + loaded_memory

-            lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
+            lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
             lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)

         if vram_set_state == VRAMState.NO_VRAM:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 8a1f8fb63..e291158ce 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1089,7 +1089,6 @@ class ModelPatcher:
     def patch_hooks(self, hooks: comfy.hooks.HookGroup):
         with self.use_ejected():
-            self.unpatch_hooks()
             if hooks is not None:
                 model_sd_keys = list(self.model_state_dict().keys())
                 memory_counter = None
@@ -1100,12 +1099,16 @@
                 # if have cached weights for hooks, use it
                 cached_weights = self.cached_hook_patches.get(hooks, None)
                 if cached_weights is not None:
+                    model_sd_keys_set = set(model_sd_keys)
                    for key in cached_weights:
                        if key not in model_sd_keys:
                            logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                            continue
                        self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
+                       model_sd_keys_set.remove(key)
+                   self.unpatch_hooks(model_sd_keys_set)
                 else:
+                    self.unpatch_hooks()
                     relevant_patches = self.get_combined_hook_patches(hooks=hooks)
                     original_weights = None
                     if len(relevant_patches) > 0:
@@ -1116,6 +1119,8 @@
                             continue
                         self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights, memory_counter=memory_counter)
+            else:
+                self.unpatch_hooks()
             self.current_hooks = hooks

     def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@@ -1172,17 +1177,23 @@
             del out_weight
             del weight

-    def unpatch_hooks(self) -> None:
+    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
         with self.use_ejected():
             if len(self.hook_backup) == 0:
                 self.current_hooks = None
                 return
             keys = list(self.hook_backup.keys())
-            for k in keys:
-                comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+            if whitelist_keys_set:
+                for k in keys:
+                    if k in whitelist_keys_set:
+                        comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+                        self.hook_backup.pop(k)
+            else:
+                for k in keys:
+                    comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))

-            self.hook_backup.clear()
-            self.current_hooks = None
+                self.hook_backup.clear()
+                self.current_hooks = None

     def clean_hooks(self):
         self.unpatch_hooks()
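
Note: patch_hooks previously restored every backed-up weight before applying a new hook set. With this change, when cached weights exist for the requested hooks, only the keys the cache does not cover are restored, via the new whitelist parameter of unpatch_hooks; the full restore still happens on the uncached path and when hooks is None. A rough sketch of the bookkeeping (hypothetical helper names, not the ModelPatcher API):

    # keys the cached hook weights will overwrite do not need to be restored first
    remaining = set(model_state_dict_keys)
    for key in cached_weights:
        apply_cached_weight(key)      # hypothetical: write the cached weight into the model
        remaining.discard(key)
    unpatch_hooks(whitelist_keys_set=remaining)   # restore only what the cache leaves untouched
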
diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py
index 56aef9b01..504010ad0 100644
--- a/comfy_extras/nodes_hunyuan.py
+++ b/comfy_extras/nodes_hunyuan.py
@@ -68,7 +68,6 @@ class TextEncodeHunyuanVideo_ImageToVideo:
         tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
         return (clip.encode_from_tokens_scheduled(tokens), )

-
 class HunyuanImageToVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -78,6 +77,7 @@ class HunyuanImageToVideo:
                              "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
                              "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+                             "guidance_type": (["v1 (concat)", "v2 (replace)"], )
                },
                "optional": {"start_image": ("IMAGE", ),
                }}
@@ -88,8 +88,10 @@ class HunyuanImageToVideo:

     CATEGORY = "conditioning/video_models"

-    def encode(self, positive, vae, width, height, length, batch_size, start_image=None):
+    def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        out_latent = {}
+
         if start_image is not None:
             start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)

@@ -97,13 +99,20 @@ class HunyuanImageToVideo:
             mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
             mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0

-            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
+            if guidance_type == "v1 (concat)":
+                cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
+            else:
+                cond = {'guiding_frame_index': 0}
+                latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
+                out_latent["noise_mask"] = mask
+
+            positive = node_helpers.conditioning_set_values(positive, cond)

-        out_latent = {}
         out_latent["samples"] = latent
         return (positive, out_latent)

+
 NODE_CLASS_MAPPINGS = {
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
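
Note: the new guidance_type input selects between the two image-to-video conditioning schemes. "v1 (concat)" keeps the previous behaviour (the encoded start image and its mask are attached as concat_latent_image / concat_mask conditioning), while "v2 (replace)" writes the encoded start image directly into the first latent frames, attaches the mask as the latent's noise_mask, and sets guiding_frame_index = 0 so the model takes the token-replace modulation path added above. In terms of the conditioning values each branch attaches:

    cond_v1 = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
    cond_v2 = {"guiding_frame_index": 0}   # plus: latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
                                           #       and out_latent["noise_mask"] = mask
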
(concat)": + cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask} + else: + cond = {'guiding_frame_index': 0} + latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image + out_latent["noise_mask"] = mask + + positive = node_helpers.conditioning_set_values(positive, cond) - out_latent = {} out_latent["samples"] = latent return (positive, out_latent) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT, "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, diff --git a/comfyui_version.py b/comfyui_version.py index a68a65323..b5e6fbead 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.3.24" +__version__ = "0.3.26" diff --git a/pyproject.toml b/pyproject.toml index 4c11c71bb..f13fed8dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.24" +version = "0.3.26" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9"