Mirror of https://github.com/comfyanonymous/ComfyUI.git
work on the conditioning

commit 663d971830
parent 4241f106dc
@@ -122,7 +122,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
    elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
    elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd or "vision_model.encoder.layers.11.layer_norm1.weight" in sd:
    elif "vision_model.encoder.layers.11.layer_norm1.weight" in sd:
        embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
        norm_weight = sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0]
        if norm_weight == 1152:
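Note: the hunk above changes which layer_norm keys are probed when picking a vision config. A minimal standalone sketch of the idea, not ComfyUI's function; the 22/11-layer config names and the fallback below are illustrative placeholders:

import torch

def guess_vision_config(sd):
    # The deepest encoder layer present in the state dict hints at the
    # architecture: e.g. "layers.30" only exists in very deep encoders.
    if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
        return "clip_vision_config_h.json"
    if "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
        return "24_layer_config.json"      # placeholder name
    if "vision_model.encoder.layers.11.layer_norm1.weight" in sd:
        return "12_layer_config.json"      # placeholder name
    return "default_config.json"           # placeholder fallback

sd = {"vision_model.encoder.layers.11.layer_norm1.weight": torch.zeros(768)}
print(guess_vision_config(sd))  # 12_layer_config.json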
@@ -635,6 +635,24 @@ class SingleStreamBlock(nn.Module):

        return x


def trim_repeats(expanded):
    _, L, D = expanded.shape
    seq = expanded[0]

    repeat_len = L
    for k in range(1, L // 2 + 1):
        if torch.equal(seq[:k], seq[k:2*k]):
            repeat_len = k
            break

    repeat_dim = D
    for k in range(1, D // 2 + 1):
        if torch.equal(seq[:, :k], seq[:, k:2*k]):
            repeat_dim = k
            break

    return expanded[:, :repeat_len, :repeat_dim]


class HunyuanVideoFoley(nn.Module):
    def __init__(
        self,
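Note: a minimal usage sketch of trim_repeats (the helper is copied from the hunk above so the snippet runs on its own; shapes are made up). It shows the intended round trip: a conditioning tensor that was tiled along the sequence axis is shrunk back by finding the first repeating prefix. The check is a heuristic and can over-trim if the untiled sequence itself begins with a repeated block.

import torch

def trim_repeats(expanded):
    # copy of the helper from the hunk above, for a standalone demo
    _, L, D = expanded.shape
    seq = expanded[0]
    repeat_len = L
    for k in range(1, L // 2 + 1):
        if torch.equal(seq[:k], seq[k:2*k]):
            repeat_len = k
            break
    repeat_dim = D
    for k in range(1, D // 2 + 1):
        if torch.equal(seq[:, :k], seq[:, k:2*k]):
            repeat_dim = k
            break
    return expanded[:, :repeat_len, :repeat_dim]

base = torch.randn(1, 5, 8)          # original conditioning: L=5, D=8
tiled = base.repeat(1, 3, 1)         # tiled along the sequence axis to L=15
restored = trim_repeats(tiled)
print(restored.shape)                # torch.Size([1, 5, 8])
print(torch.equal(restored, base))   # True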
@@ -810,18 +828,30 @@ class HunyuanVideoFoley(nn.Module):
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        full_cond: torch.Tensor,
        context: torch.Tensor,
        control = None,
        transformer_options = {},
        drop_visual: Optional[List[bool]] = None,
    ):
        device = x.device
        audio = x
        bs, _, ol = x.shape
        tl = ol // self.patch_size

        condition, uncondition = torch.chunk(2, full_cond)
        uncond_1, uncond_2, uncond_3 = torch.chunk(3, uncondition)
        clip_feat, sync_feat, cond = torch.chunk(3, condition)
        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([uncond_3, cond])
        condition, uncondition = torch.chunk(context, 2)

        condition = condition.view(3, context.size(1) // 3, -1)
        uncondition = uncondition.view(3, context.size(1) // 3, -1)

        uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
        clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
        cond_pos, cond_neg = trim_repeats(cond_pos), trim_repeats(cond_neg)

        uncond_1, clip_feat = uncond_1.to(device, non_blocking = True), clip_feat.to(device, non_blocking=True)
        uncond_2, sync_feat = uncond_2.to(device, non_blocking = True), sync_feat.to(device, non_blocking=True)
        cond_neg, cond_pos = cond_neg.to(device, non_blocking = True), cond_pos.to(device, non_blocking=True)

        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])

        if drop_visual is not None:
            clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
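Note: a minimal sketch of the packing convention this forward pass appears to assume (shapes are invented; the cond/uncond batch stacking done by the sampler is omitted). The conditioning node flattens (clip, sync, text) into one (1, 3*L, D) tensor, and the model reverses that with view plus chunk.

import torch

L, D = 64, 768
clip_feat = torch.randn(1, L, D)
sync_feat = torch.randn(1, L, D)
text_feat = torch.randn(1, L, D)

# pack: what the conditioning node does for one (positive or negative) entry
packed = torch.cat([clip_feat, sync_feat, text_feat]).view(1, -1, D)  # (1, 3L, D)

# unpack: what the model's forward does after splitting cond/uncond
parts = packed.view(3, packed.size(1) // 3, -1)     # (3, L, D)
clip_u, sync_u, text_u = torch.chunk(parts, 3)      # three (1, L, D) tensors
print(torch.equal(clip_u, clip_feat))               # True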
@@ -351,7 +351,7 @@ class ClapTextModelWithProjection(nn.Module):
        pooled_output = text_outputs[1]
        text_embeds = self.text_projection(pooled_output)

        return text_embeds, text_outputs[0]
        return text_outputs[0], torch.tensor([]), text_embeds


class ClapTextEncoderModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
@@ -36,20 +36,45 @@ class HunyuanFoleyConditioning(io.ComfyNode):
            inputs = [
                io.Conditioning.Input("siglip_encoding_1"),
                io.Conditioning.Input("synchformer_encoding_2"),
                io.Conditioning.Input("text_encoding"),
                io.Conditioning.Input("text_encoding_positive"),
                io.Conditioning.Input("text_encoding_negative"),
            ],
            outputs=[io.Conditioning.Output(display_name= "positive"), io.Conditioning.Output(display_name="negative")]
        )

    @classmethod
    def execute(cls, siglip_encoding_1, synchformer_encoding_2, text_encoding):
    def execute(cls, siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative):

        text_encoding_positive = text_encoding_positive[0][0]
        text_encoding_negative = text_encoding_negative[0][0]
        all_ = (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)

        if isinstance(text_encoding, list):
            text_encoding = text_encoding[0]
        max_l = max([t.size(1) for t in all_])
        max_d = max([t.size(2) for t in all_])

        def repeat_shapes(max_value, input, dim = 1):
            # temporary repeat values on the cpu
            factor_pos, remainder = divmod(max_value, input.shape[dim])

            positions = [1] * input.ndim
            positions[dim] = factor_pos
            input = input.cpu().repeat(*positions)

            if remainder > 0:
                pad = input[:, :remainder, :]
                input = torch.cat([input, pad], dim =1)

            return input

        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in all_]

        embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim = 0)

        x = siglip_encoding_1
        negative = [[torch.cat([torch.zeros_like(embeds), text_encoding_negative]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
        positive = [[torch.cat([embeds, text_encoding_positive]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]

        embeds = torch.cat([siglip_encoding_1, synchformer_encoding_2, text_encoding], dim = 0)
        positive = [[embeds, {}]]
        negative = [[torch.zeros_like(embeds), {}]]
        return io.NodeOutput(positive, negative)


class FoleyExtension(ComfyExtension):
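Note: a minimal sketch mirroring the repeat_shapes helper above (repeat_to is a hypothetical name; it generalizes the padding to any dim via narrow instead of a hard-coded slice). A tensor is padded to a target size along one axis by tiling whole copies and, when the target is not an exact multiple, appending a leading slice; trim_repeats in the model is meant to undo this.

import torch

def repeat_to(t, target, dim=1):
    # tile whole copies, then pad with a leading slice for the remainder
    factor, remainder = divmod(target, t.shape[dim])
    reps = [1] * t.ndim
    reps[dim] = factor
    out = t.repeat(*reps)
    if remainder > 0:
        pad = out.narrow(dim, 0, remainder)
        out = torch.cat([out, pad], dim=dim)
    return out

text = torch.randn(1, 77, 768)
print(repeat_to(text, 200, dim=1).shape)   # torch.Size([1, 200, 768])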
@@ -59,10 +59,11 @@ class EncodeVideo(io.ComfyNode):
            raise ValueError("Must either have vae or clip_vision.")
        elif vae is None and clip_vision is None:
            raise ValueError("Can't have VAE and Clip Vision passed at the same time!")
        model = vae.first_stage_model if vae is not None else clip_vision.model
        vae = vae if vae is not None else clip_vision

        if hasattr(vae.first_stage_model, "video_encoding"):
            data, num_segments, output_fn = vae.first_stage_model.video_encoding(video, step_size)
        if hasattr(model, "video_encoding"):
            data, num_segments, output_fn = model.video_encoding(video, step_size)
            batch_size = b * num_segments
        else:
            data = video.view(batch_size, c, h, w)
@@ -77,7 +78,11 @@ class EncodeVideo(io.ComfyNode):
        with torch.inference_mode():
            for i in range(0, total, batch_size):
                chunk = data[i : i + batch_size]
                out = vae.encode(chunk)
                if hasattr(vae, "encode"):
                    out = vae.encode(chunk)
                else:
                    out = vae.encode_image(chunk)
                    out = out["image_embeds"]
                outputs.append(out)
                del out, chunk
                torch.cuda.empty_cache()
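Note: a minimal sketch of the chunked-encoding dispatch added above (encode_in_chunks and encoder are hypothetical names; the "image_embeds" subscript follows the hunk, not a documented API). Data is encoded in fixed-size batches, calling .encode() when the wrapper has one and falling back to .encode_image() otherwise.

import torch

def encode_in_chunks(encoder, data, batch_size):
    outputs = []
    with torch.inference_mode():
        for i in range(0, data.shape[0], batch_size):
            chunk = data[i:i + batch_size]
            if hasattr(encoder, "encode"):
                # VAE-style wrapper
                out = encoder.encode(chunk)
            else:
                # CLIP-vision-style wrapper, per the hunk above
                out = encoder.encode_image(chunk)["image_embeds"]
            outputs.append(out)
            del out, chunk
    return torch.cat(outputs)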