work on the conditioning

This commit is contained in:
Yousef Rafat 2025-10-04 00:18:03 +03:00
parent 4241f106dc
commit 663d971830
5 changed files with 77 additions and 17 deletions

View File

@ -122,7 +122,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd or "vision_model.encoder.layers.11.layer_norm1.weight" in sd:
elif "vision_model.encoder.layers.11.layer_norm1.weight" in sd:
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
norm_weight = sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0]
if norm_weight == 1152:

View File

@ -635,6 +635,24 @@ class SingleStreamBlock(nn.Module):
return x
def trim_repeats(expanded):
_, L, D = expanded.shape
seq = expanded[0]
repeat_len = L
for k in range(1, L // 2 + 1):
if torch.equal(seq[:k], seq[k:2*k]):
repeat_len = k
break
repeat_dim = D
for k in range(1, D // 2 + 1):
if torch.equal(seq[:, :k], seq[:, k:2*k]):
repeat_dim = k
break
return expanded[:, :repeat_len, :repeat_dim]
class HunyuanVideoFoley(nn.Module):
def __init__(
self,
@ -810,18 +828,30 @@ class HunyuanVideoFoley(nn.Module):
self,
x: torch.Tensor,
t: torch.Tensor,
full_cond: torch.Tensor,
context: torch.Tensor,
control = None,
transformer_options = {},
drop_visual: Optional[List[bool]] = None,
):
device = x.device
audio = x
bs, _, ol = x.shape
tl = ol // self.patch_size
condition, uncondition = torch.chunk(2, full_cond)
uncond_1, uncond_2, uncond_3 = torch.chunk(3, uncondition)
clip_feat, sync_feat, cond = torch.chunk(3, condition)
clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([uncond_3, cond])
condition, uncondition = torch.chunk(context, 2)
condition = condition.view(3, context.size(1) // 3, -1)
uncondition = uncondition.view(3, context.size(1) // 3, -1)
uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
cond_pos, cond_neg = trim_repeats(cond_pos), trim_repeats(cond_neg)
uncond_1, clip_feat = uncond_1.to(device, non_blocking = True), clip_feat.to(device, non_blocking=True)
uncond_2, sync_feat = uncond_2.to(device, non_blocking = True), sync_feat.to(device, non_blocking=True)
cond_neg, cond_pos = cond_neg.to(device, non_blocking = True), cond_pos.to(device, non_blocking=True)
clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
if drop_visual is not None:
clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)

View File

@ -351,7 +351,7 @@ class ClapTextModelWithProjection(nn.Module):
pooled_output = text_outputs[1]
text_embeds = self.text_projection(pooled_output)
return text_embeds, text_outputs[0]
return text_outputs[0], torch.tensor([]), text_embeds
class ClapTextEncoderModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):

View File

@ -36,20 +36,45 @@ class HunyuanFoleyConditioning(io.ComfyNode):
inputs = [
io.Conditioning.Input("siglip_encoding_1"),
io.Conditioning.Input("synchformer_encoding_2"),
io.Conditioning.Input("text_encoding"),
io.Conditioning.Input("text_encoding_positive"),
io.Conditioning.Input("text_encoding_negative"),
],
outputs=[io.Conditioning.Output(display_name= "positive"), io.Conditioning.Output(display_name="negative")]
)
@classmethod
def execute(cls, siglip_encoding_1, synchformer_encoding_2, text_encoding):
def execute(cls, siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative):
text_encoding_positive = text_encoding_positive[0][0]
text_encoding_negative = text_encoding_negative[0][0]
all_ = (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)
if isinstance(text_encoding, list):
text_encoding = text_encoding[0]
max_l = max([t.size(1) for t in all_])
max_d = max([t.size(2) for t in all_])
def repeat_shapes(max_value, input, dim = 1):
# temporary repeat values on the cpu
factor_pos, remainder = divmod(max_value, input.shape[dim])
positions = [1] * input.ndim
positions[dim] = factor_pos
input = input.cpu().repeat(*positions)
if remainder > 0:
pad = input[:, :remainder, :]
input = torch.cat([input, pad], dim =1)
return input
siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in all_]
embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim = 0)
x = siglip_encoding_1
negative = [[torch.cat([torch.zeros_like(embeds), text_encoding_negative]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
positive = [[torch.cat([embeds, text_encoding_positive]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
embeds = torch.cat([siglip_encoding_1, synchformer_encoding_2, text_encoding], dim = 0)
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return io.NodeOutput(positive, negative)
class FoleyExtension(ComfyExtension):

View File

@ -59,10 +59,11 @@ class EncodeVideo(io.ComfyNode):
raise ValueError("Must either have vae or clip_vision.")
elif vae is None and clip_vision is None:
raise ValueError("Can't have VAE and Clip Vision passed at the same time!")
model = vae.first_stage_model if vae is not None else clip_vision.model
vae = vae if vae is not None else clip_vision
if hasattr(vae.first_stage_model, "video_encoding"):
data, num_segments, output_fn = vae.first_stage_model.video_encoding(video, step_size)
if hasattr(model, "video_encoding"):
data, num_segments, output_fn = model.video_encoding(video, step_size)
batch_size = b * num_segments
else:
data = video.view(batch_size, c, h, w)
@ -77,7 +78,11 @@ class EncodeVideo(io.ComfyNode):
with torch.inference_mode():
for i in range(0, total, batch_size):
chunk = data[i : i + batch_size]
out = vae.encode(chunk)
if hasattr(vae, "encode"):
out = vae.encode(chunk)
else:
out = vae.encode_image(chunk)
out = out["image_embeds"]
outputs.append(out)
del out, chunk
torch.cuda.empty_cache()