fixed multiple errors in nodes and model loading

Yousef Rafat 2025-09-29 22:44:40 +03:00
parent a6dabd2855
commit 42a265cddf
9 changed files with 143 additions and 102 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Tuple, Optional, Union, Dict
from typing import List, Tuple, Optional, Union
from functools import partial
import math
@@ -638,17 +638,19 @@ class SingleStreamBlock(nn.Module):
class HunyuanVideoFoley(nn.Module):
def __init__(
self,
model_args,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
operations = None
operations = None,
**kwargs
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.dtype = dtype
self.depth_triple_blocks = 18
self.depth_single_blocks = 36
model_args = {}
self.interleaved_audio_visual_rope = model_args.get("interleaved_audio_visual_rope", True)

View File

@@ -850,8 +850,8 @@ class GlobalTransformer(torch.nn.Module):
self.vis_in_lnorm = operations.LayerNorm(n_embd, **factory_kwargs)
self.aud_in_lnorm = operations.LayerNorm(n_embd, **factory_kwargs)
# aux tokens
self.OFF_tok = operations.Parameter(torch.randn(1, 1, n_embd, **factory_kwargs))
self.MOD_tok = operations.Parameter(torch.randn(1, 1, n_embd, **factory_kwargs))
self.OFF_tok = nn.Parameter(torch.randn(1, 1, n_embd, **factory_kwargs))
self.MOD_tok = nn.Parameter(torch.randn(1, 1, n_embd, **factory_kwargs))
# whole token dropout
self.tok_pdrop = tok_pdrop
self.tok_drop_vis = torch.nn.Dropout1d(tok_pdrop)
@@ -863,7 +863,7 @@ class GlobalTransformer(torch.nn.Module):
)
# the stem
self.drop = torch.nn.Dropout(embd_pdrop)
self.blocks = operations.Sequential(*[Block(self.config, operations=operations, **factory_kwargs) for _ in range(n_layer)])
self.blocks = nn.Sequential(*[Block(self.config, operations=operations, **factory_kwargs) for _ in range(n_layer)])
# pre-output norm
self.ln_f = operations.LayerNorm(n_embd)
# maybe add a head

View File

@@ -5,7 +5,7 @@ from typing import List
import torch.nn as nn
from einops import rearrange
from torchvision.transforms import v2
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils import weight_norm
from comfy.ldm.hunyuan_foley.syncformer import Synchformer
@@ -154,6 +154,7 @@ class DACDecoder(nn.Module):
layers += [
Snake1d(output_dim, device = device, dtype = dtype),
WNConv1d(output_dim, d_out, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
@@ -164,11 +165,11 @@ class DACDecoder(nn.Module):
class DAC(torch.nn.Module):
def __init__(
self,
encoder_dim: int = 64,
encoder_rates: List[int] = [2, 4, 8, 8],
latent_dim: int = None,
decoder_dim: int = 1536,
decoder_rates: List[int] = [8, 8, 4, 2],
encoder_dim: int = 128,
encoder_rates: List[int] = [2, 3, 4, 5],
latent_dim: int = 128,
decoder_dim: int = 2048,
decoder_rates: List[int] = [8, 5, 4, 3],
sample_rate: int = 44100,
):
super().__init__()
@@ -204,6 +205,7 @@ class DAC(torch.nn.Module):
class FoleyVae(torch.nn.Module):
def __init__(self):
super().__init__()
self.dac = DAC()
self.syncformer = Synchformer(None, None, operations = ops)
self.syncformer_preprocess = v2.Compose(

View File

@@ -422,7 +422,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
if '{}triple_blocks.17.audio_cross_q.weight'.format(key_prefix) in state_dict_keys: # Hunyuan Foley
return {}
dit_config = {}
dit_config["image_model"] = "hunyuan_foley"
return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape

View File

@@ -508,7 +508,10 @@ class VAE:
self.latent_dim = 128
self.first_stage_model = comfy.ldm.hunyuan_foley.vae.FoleyVae()
# TODO
self.memory_used_encode = lambda shape, dtype: shape[0] * model_management.dtype_size(dtype)
encode_layers = 25
decode_layers = 4
self.memory_used_encode = lambda shape, dtype: torch.prod(shape) * model_management.dtype_size(dtype) * encode_layers
self.memory_used_decode = lambda shape, dtype: torch.prod(shape) * model_management.dtype_size(dtype) * decode_layers
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)

View File

@@ -66,7 +66,6 @@ class ClapTextEmbeddings(nn.Module):
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long, device=device), persistent=True
)
# End copy
self.padding_idx = pad_token_id
self.position_embeddings = operations.Embedding(
max_position_embeddings, hidden_size, padding_idx=self.padding_idx, device=device, dtype=dtype
@@ -145,6 +144,7 @@ class ClapTextSelfAttention(nn.Module):
value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
query_states, key_states, value_states = [t.contiguous() for t in (query_states, key_states, value_states)]
attention_mask = attention_mask.to(query_states.dtype)
attn_output = optimized_attention(query_states, key_states, value_states, self.num_attention_heads, mask = attention_mask, skip_output_reshape=True, skip_reshape=True)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output.reshape(*input_shape, -1).contiguous()
@@ -271,16 +271,16 @@ class ClapTextModel(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
embeds: Optional[torch.Tensor] = None,
):
if input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
elif embeds is not None:
input_shape = embeds.size()[:-1]
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
device = input_ids.device if input_ids is not None else embeds.device
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
@@ -294,7 +294,7 @@ class ClapTextModel(nn.Module):
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
inputs_embeds=embeds,
)
encoder_outputs = self.encoder(
embedding_output,
@@ -308,6 +308,10 @@ class ClapTextModel(nn.Module):
class ClapTextModelWithProjection(nn.Module):
def __init__(
self,
config,
dtype=None,
device=None,
operations=None,
hidden_size: int = 768,
intermediate_size: int = 3072,
layer_norm_eps: float = 1e-12,
@@ -318,26 +322,30 @@ class ClapTextModelWithProjection(nn.Module):
type_vocab_size: int = 1,
vocab_size: int = 50265,
pad_token_id: int = 1,
device=None,
dtype=None,
operations=None
):
super().__init__()
self.num_layers = num_hidden_layers
self.text_model = ClapTextModel(num_attention_heads, vocab_size, hidden_size, intermediate_size, pad_token_id, max_position_embeddings,
type_vocab_size, layer_norm_eps, num_hidden_layers, device=device, dtype=dtype, operations=operations)
self.text_projection = ClapProjectionLayer(hidden_size, projection_dim, device=device, dtype=dtype, operations=operations,)
def get_input_embeddings(self):
return self.text_model.embeddings.word_embeddings
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
embeds = None,
**kwargs
):
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
embeds=embeds
)
pooled_output = text_outputs[1]
@@ -347,9 +355,10 @@ class ClapTextModelWithProjection(nn.Module):
class ClapTextEncoderModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
self.dtypes = set([dtype])
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 1}, layer_norm_hidden_state=False, model_class=ClapTextModelWithProjection, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ClapLargeTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clap_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='clap_l', tokenizer_class=AutoTokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='clap_l', tokenizer_class=AutoTokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=1, tokenizer_data=tokenizer_data)

View File

@@ -89,7 +89,7 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self, return_frames=False) -> float:
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
@@ -100,8 +100,7 @@ class VideoFromFile(VideoInput):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
if not return_frames:
return float(container.duration / av.time_base)
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
@@ -109,8 +108,6 @@ class VideoFromFile(VideoInput):
)
if video_stream and video_stream.frames and video_stream.average_rate:
length = float(video_stream.frames / video_stream.average_rate)
if return_frames:
return length, float(video_stream.frames)
return length
# Last resort: decode frames to count them
@@ -122,8 +119,6 @@ class VideoFromFile(VideoInput):
frame_count += 1
if frame_count > 0:
length = float(frame_count / video_stream.average_rate)
if return_frames:
return length, float(frame_count)
return length
raise ValueError(f"Could not determine duration for file '{self.__file}'")

View File

@@ -1,53 +1,60 @@
import torch
import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class EmptyLatentHunyuanFoley:
class EmptyLatentHunyuanFoley(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"length": ("INT", {"default": 12, "min": 1, "max": 15, "tooltip": "The length of the audio. The same length as the video."}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent audios in the batch."}),
},
"optional": {"video": ("VIDEO")}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, length, batch_size, video = None):
def define_schema(cls):
return io.Schema(
node_id="EmptyLatentHunyuanFoley",
display_name="EmptyLatentHunyuanFoley",
category="audio/latent",
inputs = [
io.Int.Input("length", min = 1, max = 15, default = 12),
io.Int.Input("batch_size", min = 1, max = 48_000, default = 1),
io.Video.Input("video", optional=True),
],
outputs=[io.Latent.Output(display_name="latent")]
)
@classmethod
def execute(cls, length, batch_size, video = None):
if video is not None:
_, length = video.get_duration(return_frames = True)
length = video.size(0)
length /= 25
shape = (batch_size, 128, int(50 * length))
latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
return ({"samples": latent, "type": "hunyuan_foley"}, )
return io.NodeOutput({"samples": latent, "type": "hunyuan_foley"}, )
class HunyuanFoleyConditioning:
class HunyuanFoleyConditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"video_encoding_siglip": ("CONDITIONING",),
"video_encoding_synchformer": ("CONDITIONING",),
"text_encoding": ("CONDITIONING",)
},
}
def define_schema(cls):
return io.Schema(
node_id="HunyuanFoleyConditioning",
display_name="HunyuanFoleyConditioning",
category="conditioning/video_models",
inputs = [
io.Conditioning.Input("video_encoding_1"),
io.Conditioning.Input("video_encoding_2"),
io.Conditioning.Input("text_encoding"),
],
outputs=[io.Conditioning.Output(display_name= "positive"), io.Conditioning.Output(display_name="negative")]
)
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, video_encoding_1, video_encoding_2, text_encoding):
@classmethod
def execute(cls, video_encoding_1, video_encoding_2, text_encoding):
embeds = torch.cat([video_encoding_1, video_encoding_2, text_encoding], dim = 0)
positive = [[embeds, {}]]
negative = [[torch.zeros_like(embeds), {}]]
return (positive, negative)
return io.NodeOutput(positive, negative)
NODE_CLASS_MAPPINGS = {
"HunyuanFoleyConditioning": HunyuanFoleyConditioning,
"EmptyLatentHunyuanFoley": EmptyLatentHunyuanFoley,
}
class FoleyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
HunyuanFoleyConditioning,
EmptyLatentHunyuanFoley
]
async def comfy_entrypoint() -> FoleyExtension:
return FoleyExtension()

View File

@@ -6,7 +6,6 @@ import av
import torch
import folder_paths
import json
import numpy as np
from typing import Optional
from typing_extensions import override
from fractions import Fraction
@@ -50,15 +49,18 @@ class EncodeVideo(io.ComfyNode):
@classmethod
def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
b, t, c, h, w = video.shape
t, c, h, w = video.shape
b = 1
batch_size = b * t
if vae is None and clip_vision is None:
if vae is not None and clip_vision is not None:
raise ValueError("Can't have VAE and Clip Vision passed at the same time!")
elif vae is None and clip_vision is None:
raise ValueError("Must have either vae or clip_vision.")
vae = vae if vae is not None else clip_vision
if hasattr(vae.first_stage_model, "video_encoding"):
data, num_segments, output_fn = vae.video_encoding(video, step_size)
data, num_segments, output_fn = vae.first_stage_model.video_encoding(video, step_size)
batch_size = b * num_segments
else:
data = video.view(batch_size, c, h, w)
@@ -76,7 +78,7 @@ class EncodeVideo(io.ComfyNode):
output = torch.cat(outputs)
return output_fn(output)
return io.NodeOutput(output_fn(output))
class ResampleVideo(io.ComfyNode):
@classmethod
@@ -87,44 +89,62 @@ class ResampleVideo(io.ComfyNode):
category="image/video",
inputs = [
io.Video.Input("video"),
io.Int.Input("target_fps")
io.Int.Input("target_fps", min=1, default=25)
],
outputs=[io.Image.Output(display_name="images")]
outputs=[io.Video.Output(display_name="video")]
)
@classmethod
def execute(cls, container: av.container.InputContainer, target_fps: int):
def execute(cls, video, target_fps: int):
# doesn't support upsampling
stream = container.streams.video[0]
frames = []
with av.open(video.get_stream_source(), mode="r") as container:
stream = container.streams.video[0]
frames = []
src_rate = stream.average_rate or stream.guessed_rate
src_fps = float(src_rate) if src_rate else None
src_rate = stream.average_rate or stream.guessed_rate
src_fps = float(src_rate) if src_rate else None
# yield original frames if asked for upsampling or src is unknown
if src_fps is None or target_fps > src_fps:
for packet in container.demux(stream):
for frame in packet.decode():
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
frames.append(arr)
return torch.stack(frames)
stream.thread_type = "AUTO"
next_time = 0.0
step = 1.0 / target_fps
# yield original frames if asked for upsampling or src is unknown
if src_fps is None or target_fps > src_fps:
for packet in container.demux(stream):
for frame in packet.decode():
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
frames.append(arr)
return torch.stack(frames)
if frame.time is None:
continue
t = frame.time
while t >= next_time:
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
frames.append(arr)
next_time += step
stream.thread_type = "AUTO"
return io.NodeOutput(torch.stack(frames))
next_time = 0.0
step = 1.0 / target_fps
class VideoToImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VideoToImage",
category="image/video",
display_name = "Video To Images",
inputs=[io.Video.Input("video")],
outputs=[io.Image.Output("images")]
)
@classmethod
def execute(cls, video):
with av.open(video.get_stream_source(), mode="r") as container:
components = video.get_components_internal(container)
for packet in container.demux(stream):
for frame in packet.decode():
if frame.time is None:
continue
t = frame.time
while t >= next_time:
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
frames.append(arr)
next_time += step
return torch.stack(frames)
images = components.images
return io.NodeOutput(images)
class SaveWEBM(io.ComfyNode):
@classmethod
@@ -325,7 +345,8 @@ class VideoExtension(ComfyExtension):
GetVideoComponents,
LoadVideo,
EncodeVideo,
ResampleVideo
ResampleVideo,
VideoToImage
]
async def comfy_entrypoint() -> VideoExtension: