diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 4e08b6c08..a8127d18e 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -122,7 +122,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
-    elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd or "vision_model.encoder.layers.11.layer_norm1.weight" in sd:
+    elif "vision_model.encoder.layers.11.layer_norm1.weight" in sd:
         embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
         norm_weight = sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0]
         if norm_weight == 1152:
diff --git a/comfy/ldm/hunyuan_foley/model.py b/comfy/ldm/hunyuan_foley/model.py
index e9ca258ba..b8705fb8d 100644
--- a/comfy/ldm/hunyuan_foley/model.py
+++ b/comfy/ldm/hunyuan_foley/model.py
@@ -635,6 +635,24 @@ class SingleStreamBlock(nn.Module):
         return x
 
 
+def trim_repeats(expanded):
+    _, L, D = expanded.shape
+    seq = expanded[0]
+
+    repeat_len = L
+    for k in range(1, L // 2 + 1):
+        if torch.equal(seq[:k], seq[k:2*k]):
+            repeat_len = k
+            break
+
+    repeat_dim = D
+    for k in range(1, D // 2 + 1):
+        if torch.equal(seq[:, :k], seq[:, k:2*k]):
+            repeat_dim = k
+            break
+
+    return expanded[:, :repeat_len, :repeat_dim]
+
 class HunyuanVideoFoley(nn.Module):
     def __init__(
         self,
@@ -810,18 +828,30 @@ class HunyuanVideoFoley(nn.Module):
         self,
         x: torch.Tensor,
         t: torch.Tensor,
-        full_cond: torch.Tensor,
+        context: torch.Tensor,
+        control = None,
         transformer_options = {},
         drop_visual: Optional[List[bool]] = None,
     ):
+        device = x.device
         audio = x
         bs, _, ol = x.shape
         tl = ol // self.patch_size
 
-        condition, uncondition = torch.chunk(2, full_cond)
-        uncond_1, uncond_2, uncond_3 = torch.chunk(3, uncondition)
-        clip_feat, sync_feat, cond = torch.chunk(3, condition)
-        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([uncond_3, cond])
+        condition, uncondition = torch.chunk(context, 2)
+
+        condition = condition.view(3, context.size(1) // 3, -1)
+        uncondition = uncondition.view(3, context.size(1) // 3, -1)
+
+        uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
+        clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
+        cond_pos, cond_neg = trim_repeats(cond_pos), trim_repeats(cond_neg)
+
+        uncond_1, clip_feat = uncond_1.to(device, non_blocking = True), clip_feat.to(device, non_blocking=True)
+        uncond_2, sync_feat = uncond_2.to(device, non_blocking = True), sync_feat.to(device, non_blocking=True)
+        cond_neg, cond_pos = cond_neg.to(device, non_blocking = True), cond_pos.to(device, non_blocking=True)
+
+        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
 
         if drop_visual is not None:
             clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
diff --git a/comfy/text_encoders/clap_model.py b/comfy/text_encoders/clap_model.py
index 27bebd762..e992a52b7 100644
--- a/comfy/text_encoders/clap_model.py
+++ b/comfy/text_encoders/clap_model.py
@@ -351,7 +351,7 @@ class ClapTextModelWithProjection(nn.Module):
         pooled_output = text_outputs[1]
         text_embeds = self.text_projection(pooled_output)
 
-        return text_embeds, text_outputs[0]
+        return text_outputs[0], torch.tensor([]), text_embeds
 
 class ClapTextEncoderModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
diff --git a/comfy_extras/nodes_hunyuan_foley.py b/comfy_extras/nodes_hunyuan_foley.py
index 70c3d3d4e..89eaf2394 100644
--- a/comfy_extras/nodes_hunyuan_foley.py
+++ b/comfy_extras/nodes_hunyuan_foley.py
@@ -36,20 +36,45 @@ class HunyuanFoleyConditioning(io.ComfyNode):
             inputs = [
                 io.Conditioning.Input("siglip_encoding_1"),
                 io.Conditioning.Input("synchformer_encoding_2"),
-                io.Conditioning.Input("text_encoding"),
+                io.Conditioning.Input("text_encoding_positive"),
+                io.Conditioning.Input("text_encoding_negative"),
             ],
             outputs=[io.Conditioning.Output(display_name= "positive"), io.Conditioning.Output(display_name="negative")]
         )
 
     @classmethod
-    def execute(cls, siglip_encoding_1, synchformer_encoding_2, text_encoding):
+    def execute(cls, siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative):
+
+        text_encoding_positive = text_encoding_positive[0][0]
+        text_encoding_negative = text_encoding_negative[0][0]
+        all_ = (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)
 
-        if isinstance(text_encoding, list):
-            text_encoding = text_encoding[0]
+        max_l = max([t.size(1) for t in all_])
+        max_d = max([t.size(2) for t in all_])
+
+        def repeat_shapes(max_value, input, dim = 1):
+            # temporary repeat values on the cpu
+            factor_pos, remainder = divmod(max_value, input.shape[dim])
+
+            positions = [1] * input.ndim
+            positions[dim] = factor_pos
+            input = input.cpu().repeat(*positions)
+
+            if remainder > 0:
+                pad = input[:, :remainder, :]
+                input = torch.cat([input, pad], dim =1)
+
+            return input
+
+        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
+        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in all_]
+
+        embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim = 0)
+
+        x = siglip_encoding_1
+        negative = [[torch.cat([torch.zeros_like(embeds), text_encoding_negative]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
+        positive = [[torch.cat([embeds, text_encoding_positive]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
 
-        embeds = torch.cat([siglip_encoding_1, synchformer_encoding_2, text_encoding], dim = 0)
-        positive = [[embeds, {}]]
-        negative = [[torch.zeros_like(embeds), {}]]
         return io.NodeOutput(positive, negative)
 
 class FoleyExtension(ComfyExtension):
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 485d38804..e9509500f 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -59,10 +59,11 @@ class EncodeVideo(io.ComfyNode):
             raise ValueError("Must either have vae or clip_vision.")
         elif vae is None and clip_vision is None:
             raise ValueError("Can't have VAE and Clip Vision passed at the same time!")
+        model = vae.first_stage_model if vae is not None else clip_vision.model
         vae = vae if vae is not None else clip_vision
 
-        if hasattr(vae.first_stage_model, "video_encoding"):
-            data, num_segments, output_fn = vae.first_stage_model.video_encoding(video, step_size)
+        if hasattr(model, "video_encoding"):
+            data, num_segments, output_fn = model.video_encoding(video, step_size)
             batch_size = b * num_segments
         else:
             data = video.view(batch_size, c, h, w)
@@ -77,7 +78,11 @@ class EncodeVideo(io.ComfyNode):
         with torch.inference_mode():
             for i in range(0, total, batch_size):
                 chunk = data[i : i + batch_size]
-                out = vae.encode(chunk)
+                if hasattr(vae, "encode"):
+                    out = vae.encode(chunk)
+                else:
+                    out = vae.encode_image(chunk)
+                    out = out["image_embeds"]
                 outputs.append(out)
                 del out, chunk
                 torch.cuda.empty_cache()
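For reference, the conditioning node and the model cooperate through a repeat-then-trim convention: repeat_shapes in nodes_hunyuan_foley.py tiles every conditioning tensor up to a common sequence length and feature width so they can be packed into one context tensor, and trim_repeats in comfy/ldm/hunyuan_foley/model.py later detects the repetition period and recovers the original block. The snippet below is a minimal standalone sketch of that round trip under assumed shapes; the helper name repeat_to_length and the example sizes are illustrative and not part of the patch.

    # Standalone sketch (assumed shapes, not part of the diff): round-trip of the
    # repeat/trim convention used between HunyuanFoleyConditioning and the model.
    import torch

    def repeat_to_length(x, target, dim=1):
        # Tile x along `dim` until it reaches `target` elements, padding any
        # remainder with a leading slice (mirrors repeat_shapes in the node).
        factor, remainder = divmod(target, x.shape[dim])
        reps = [1] * x.ndim
        reps[dim] = factor
        x = x.repeat(*reps)
        if remainder > 0:
            x = torch.cat([x, x.narrow(dim, 0, remainder)], dim=dim)
        return x

    def trim_repeats(expanded):
        # Find the smallest repeating block along the sequence and feature axes
        # and keep only that block (mirrors trim_repeats in the model).
        _, L, D = expanded.shape
        seq = expanded[0]
        repeat_len = L
        for k in range(1, L // 2 + 1):
            if torch.equal(seq[:k], seq[k:2 * k]):
                repeat_len = k
                break
        repeat_dim = D
        for k in range(1, D // 2 + 1):
            if torch.equal(seq[:, :k], seq[:, k:2 * k]):
                repeat_dim = k
                break
        return expanded[:, :repeat_len, :repeat_dim]

    text = torch.randn(1, 7, 768)        # hypothetical text embedding
    packed = repeat_to_length(text, 21)  # padded to the longest sequence, here 21
    recovered = trim_repeats(packed)     # back to (1, 7, 768)
    assert torch.equal(recovered, text)

Note that the recovery is heuristic: it assumes the un-padded tensor does not itself begin with a repeated block along the trimmed axis.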