Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-14 07:22:36 +08:00)

Commit 663d971830: work on the conditioning
Parent: 4241f106dc
@@ -122,7 +122,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
-    elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd or "vision_model.encoder.layers.11.layer_norm1.weight" in sd:
+    elif "vision_model.encoder.layers.11.layer_norm1.weight" in sd:
         embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
         norm_weight = sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0]
         if norm_weight == 1152:
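As context for the hunk above: the loader picks a vision config purely by probing which layer keys exist in the state dict. A minimal sketch of that key-sniffing pattern, using a made-up state dict rather than a real checkpoint (the key names and the 1152 width are taken from the hunk; everything else is illustrative):

import torch

# Hypothetical state dict standing in for a loaded checkpoint; only the key
# names and the layer-norm width matter for this kind of config sniffing.
sd = {
    "vision_model.encoder.layers.11.layer_norm1.weight": torch.zeros(1152),
    "vision_model.encoder.layers.0.layer_norm1.weight": torch.zeros(1152),
    "vision_model.embeddings.position_embedding.weight": torch.zeros(729, 1152),
}

if "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
    detected = "deep vision tower (layer 30 present)"
elif "vision_model.encoder.layers.11.layer_norm1.weight" in sd:
    width = sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0]
    detected = f"12-layer tower, hidden width {width}"
else:
    detected = "unknown"

print(detected)  # -> 12-layer tower, hidden width 1152
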
@@ -635,6 +635,24 @@ class SingleStreamBlock(nn.Module):
 
         return x
 
+def trim_repeats(expanded):
+    _, L, D = expanded.shape
+    seq = expanded[0]
+
+    repeat_len = L
+    for k in range(1, L // 2 + 1):
+        if torch.equal(seq[:k], seq[k:2*k]):
+            repeat_len = k
+            break
+
+    repeat_dim = D
+    for k in range(1, D // 2 + 1):
+        if torch.equal(seq[:, :k], seq[:, k:2*k]):
+            repeat_dim = k
+            break
+
+    return expanded[:, :repeat_len, :repeat_dim]
+
 class HunyuanVideoFoley(nn.Module):
     def __init__(
         self,
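A quick sanity-check sketch of the new trim_repeats helper: it searches for the smallest prefix that tiles the first sequence along both the length and feature axes, then slices the tensor back down to that prefix. The helper body is copied from the hunk above so the snippet runs on its own; the shapes are made up.

import torch

def trim_repeats(expanded):
    # copied from the diff above for a self-contained check
    _, L, D = expanded.shape
    seq = expanded[0]

    repeat_len = L
    for k in range(1, L // 2 + 1):
        if torch.equal(seq[:k], seq[k:2*k]):
            repeat_len = k
            break

    repeat_dim = D
    for k in range(1, D // 2 + 1):
        if torch.equal(seq[:, :k], seq[:, k:2*k]):
            repeat_dim = k
            break

    return expanded[:, :repeat_len, :repeat_dim]

base = torch.randn(1, 5, 8)
tiled = base.repeat(1, 3, 2)                   # (1, 15, 16): tiled 3x along length, 2x along features
print(torch.equal(trim_repeats(tiled), base))  # True: the original block is recovered
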
@@ -810,18 +828,30 @@ class HunyuanVideoFoley(nn.Module):
         self,
         x: torch.Tensor,
         t: torch.Tensor,
-        full_cond: torch.Tensor,
+        context: torch.Tensor,
+        control = None,
         transformer_options = {},
         drop_visual: Optional[List[bool]] = None,
     ):
+        device = x.device
         audio = x
         bs, _, ol = x.shape
         tl = ol // self.patch_size
 
-        condition, uncondition = torch.chunk(2, full_cond)
-        uncond_1, uncond_2, uncond_3 = torch.chunk(3, uncondition)
-        clip_feat, sync_feat, cond = torch.chunk(3, condition)
-        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([uncond_3, cond])
+        condition, uncondition = torch.chunk(context, 2)
+        condition = condition.view(3, context.size(1) // 3, -1)
+        uncondition = uncondition.view(3, context.size(1) // 3, -1)
+
+        uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
+        clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
+        cond_pos, cond_neg = trim_repeats(cond_pos), trim_repeats(cond_neg)
+
+        uncond_1, clip_feat = uncond_1.to(device, non_blocking = True), clip_feat.to(device, non_blocking=True)
+        uncond_2, sync_feat = uncond_2.to(device, non_blocking = True), sync_feat.to(device, non_blocking=True)
+        cond_neg, cond_pos = cond_neg.to(device, non_blocking = True), cond_pos.to(device, non_blocking=True)
+
+        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
 
         if drop_visual is not None:
             clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
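For reference, a toy shape walk-through of the chunk/view unpacking above, assuming the layout the updated conditioning node appears to produce: three equal-length feature blocks stacked along the sequence axis, with the two halves stacked along the batch axis. The sizes are made up and stand in for the real feature dims.

import torch

L, D = 4, 8                                   # made-up per-feature length and width
context = torch.randn(2, 3 * L, D)            # two halves x three stacked features

condition, uncondition = torch.chunk(context, 2)             # dim 0 by default -> two (1, 3L, D)
condition = condition.view(3, context.size(1) // 3, -1)      # regroup to (3, L, D)
uncondition = uncondition.view(3, context.size(1) // 3, -1)

clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)   # three (1, L, D) feature blocks
uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
print(clip_feat.shape, cond_neg.shape)                       # torch.Size([1, 4, 8]) twice
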
@@ -351,7 +351,7 @@ class ClapTextModelWithProjection(nn.Module):
         pooled_output = text_outputs[1]
         text_embeds = self.text_projection(pooled_output)
 
-        return text_embeds, text_outputs[0]
+        return text_outputs[0], torch.tensor([]), text_embeds
 
 class ClapTextEncoderModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
@@ -36,20 +36,45 @@ class HunyuanFoleyConditioning(io.ComfyNode):
             inputs = [
                 io.Conditioning.Input("siglip_encoding_1"),
                 io.Conditioning.Input("synchformer_encoding_2"),
-                io.Conditioning.Input("text_encoding"),
+                io.Conditioning.Input("text_encoding_positive"),
+                io.Conditioning.Input("text_encoding_negative"),
             ],
             outputs=[io.Conditioning.Output(display_name= "positive"), io.Conditioning.Output(display_name="negative")]
         )
 
     @classmethod
-    def execute(cls, siglip_encoding_1, synchformer_encoding_2, text_encoding):
+    def execute(cls, siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative):
 
-        if isinstance(text_encoding, list):
-            text_encoding = text_encoding[0]
+        text_encoding_positive = text_encoding_positive[0][0]
+        text_encoding_negative = text_encoding_negative[0][0]
+        all_ = (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)
+
+        max_l = max([t.size(1) for t in all_])
+        max_d = max([t.size(2) for t in all_])
 
-        embeds = torch.cat([siglip_encoding_1, synchformer_encoding_2, text_encoding], dim = 0)
-        positive = [[embeds, {}]]
-        negative = [[torch.zeros_like(embeds), {}]]
+        def repeat_shapes(max_value, input, dim = 1):
+            # temporary repeat values on the cpu
+            factor_pos, remainder = divmod(max_value, input.shape[dim])
+
+            positions = [1] * input.ndim
+            positions[dim] = factor_pos
+            input = input.cpu().repeat(*positions)
+
+            if remainder > 0:
+                pad = input[:, :remainder, :]
+                input = torch.cat([input, pad], dim =1)
+
+            return input
+
+        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
+        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in all_]
+
+        embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim = 0)
+
+        x = siglip_encoding_1
+        negative = [[torch.cat([torch.zeros_like(embeds), text_encoding_negative]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
+        positive = [[torch.cat([embeds, text_encoding_positive]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
+
         return io.NodeOutput(positive, negative)
 
 class FoleyExtension(ComfyExtension):
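A small sketch of what the repeat_shapes helper in this hunk does to mismatched lengths: it tiles a tensor along the chosen axis as many whole times as fit, then pads any remainder by repeating the leading slice. The helper mirrors the one in the hunk (whitespace tidied) so the snippet is self-contained; the shapes are made up.

import torch

def repeat_shapes(max_value, input, dim=1):
    # mirrors the helper in the hunk above; tiles `input` along `dim` up to `max_value`
    factor_pos, remainder = divmod(max_value, input.shape[dim])

    positions = [1] * input.ndim
    positions[dim] = factor_pos
    input = input.cpu().repeat(*positions)

    if remainder > 0:
        # note: as written in the hunk, the remainder pad always slices along dim 1
        pad = input[:, :remainder, :]
        input = torch.cat([input, pad], dim=1)

    return input

short = torch.randn(1, 3, 8)         # e.g. a text encoding shorter than the vision tokens
stretched = repeat_shapes(7, short)  # 7 = 2 * 3 + 1 -> tile twice, then pad one step
print(stretched.shape)               # torch.Size([1, 7, 8])
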
@@ -59,10 +59,11 @@ class EncodeVideo(io.ComfyNode):
             raise ValueError("Must either have vae or clip_vision.")
         elif vae is None and clip_vision is None:
             raise ValueError("Can't have VAE and Clip Vision passed at the same time!")
+        model = vae.first_stage_model if vae is not None else clip_vision.model
         vae = vae if vae is not None else clip_vision
 
-        if hasattr(vae.first_stage_model, "video_encoding"):
-            data, num_segments, output_fn = vae.first_stage_model.video_encoding(video, step_size)
+        if hasattr(model, "video_encoding"):
+            data, num_segments, output_fn = model.video_encoding(video, step_size)
             batch_size = b * num_segments
         else:
             data = video.view(batch_size, c, h, w)
@@ -77,7 +78,11 @@ class EncodeVideo(io.ComfyNode):
         with torch.inference_mode():
             for i in range(0, total, batch_size):
                 chunk = data[i : i + batch_size]
-                out = vae.encode(chunk)
+                if hasattr(vae, "encode"):
+                    out = vae.encode(chunk)
+                else:
+                    out = vae.encode_image(chunk)
+                    out = out["image_embeds"]
                 outputs.append(out)
                 del out, chunk
                 torch.cuda.empty_cache()
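The hasattr branch added here is a simple duck-typing dispatch: anything exposing encode() is treated as a VAE-style encoder, otherwise encode_image() is called and the "image_embeds" entry of its result is kept. A toy sketch with stand-in classes, not the real ComfyUI objects:

import torch

class FakeVAE:
    def encode(self, chunk):
        return chunk.mean(dim=1, keepdim=True)            # stand-in latent

class FakeClipVision:
    def encode_image(self, chunk):
        return {"image_embeds": chunk.flatten(1)[:, :4]}  # stand-in embeddings

def encode_chunk(encoder, chunk):
    # mirrors the branch in the hunk above
    if hasattr(encoder, "encode"):
        return encoder.encode(chunk)
    out = encoder.encode_image(chunk)
    return out["image_embeds"]

chunk = torch.randn(2, 3, 8, 8)
print(encode_chunk(FakeVAE(), chunk).shape)         # torch.Size([2, 1, 8, 8])
print(encode_chunk(FakeClipVision(), chunk).shape)  # torch.Size([2, 4])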