diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index b0294ca23..9f0b7943b 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -246,7 +246,7 @@ class CLIPVision(torch.nn.Module):
             x = self.post_layernorm(x)
             if self.use_head:
                 pooled_output = self.head(x)
-            else: 
+            else:
                 pooled_output = x
         else:
             pooled_output = self.post_layernorm(x[:, 0, :])
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 78843da47..2f31f71ef 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -19,6 +19,8 @@ class Output:
 
 def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True, resize_mode="bicubic"):
     image = image[:, :, :, :3] if image.shape[3] > 3 else image
+    if image.dtype == torch.uint8:
+        image = image.float() / 255.0
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
     image = image.movedim(-1, 1)
diff --git a/comfy/ldm/hunyuan_foley/model.py b/comfy/ldm/hunyuan_foley/model.py
index 90ef95179..2e1088781 100644
--- a/comfy/ldm/hunyuan_foley/model.py
+++ b/comfy/ldm/hunyuan_foley/model.py
@@ -55,7 +55,7 @@ class TimestepEmbedder(TimestepEmbedderParent):
     def forward(self, t):
         t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
         t_emb = self.mlp(t_freq)
-        return t_emb 
+        return t_emb
 
 class SwiGLU(nn.Module):
     def __init__(self, dim: int, hidden_dim: int, device, dtype, operations):
@@ -150,9 +150,9 @@ class ChannelLastConv1d(nn.Module):
             self.register_parameter("bias", underlying.bias)
         else:
             self.register_parameter("bias", None)
-        
+
         object.__setattr__(self, "_underlying", underlying)
-        
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         self._underlying = self._underlying.to(x.dtype)
         x = self._underlying(x.permute(0, 2, 1))
@@ -204,7 +204,7 @@ class ModulateDiT(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.linear(self.act(x))
 
-    
+
 class FinalLayer1D(nn.Module):
     def __init__(self, hidden_size, patch_size, out_channels, device=None, dtype=None, operations = None):
         factory_kwargs = {"device": device, "dtype": dtype}
@@ -223,7 +223,7 @@ class FinalLayer1D(nn.Module):
         self.linear = self.linear.to(x.dtype)
         x = self.linear(x)
         return x
-    
+
 class MLP(nn.Module):
     def __init__(
         self,
@@ -254,7 +254,7 @@ class MLP(nn.Module):
         self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
-        return self.drop2(self.fc2(self.norm(self.drop1(self.act(self.fc1(x)))))) 
+        return self.drop2(self.fc2(self.norm(self.drop1(self.act(self.fc1(x))))))
 
 
 def _to_tuple(x, dim=2):
@@ -297,7 +297,7 @@ def get_meshgrid_nd(start, *args, dim=2):
 def get_nd_rotary_pos_embed(
     rope_dim_list, start, *args, theta=10000.0, use_real=False, theta_rescale_factor=1.0, freq_scaling=1.0
 ):
-    
+
     grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
 
     embs = []
@@ -411,14 +411,14 @@ class TwoStreamCABlock(nn.Module):
 
         self.max_text_len = 100
         self.rope_dim_list = None
-        
+
         self.audio_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
         self.v_cond_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
 
         self.audio_cross_q = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
         self.v_cond_cross_q = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
         self.text_cross_kv = operations.Linear(hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs)
-        
+
         self.audio_cross_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
         self.v_cond_cross_proj = operations.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
 
@@ -435,11 +435,11 @@ class TwoStreamCABlock(nn.Module):
     def build_rope_for_text(self, text_len, head_dim, rope_dim_list=None):
         target_ndim = 1  # n-d RoPE
         rope_sizes = [text_len]
-        
+
         if rope_dim_list is None:
             rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
         assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
-        
+
         text_freqs_cos, text_freqs_sin = get_nd_rotary_pos_embed(
             rope_dim_list=rope_dim_list,
             start=rope_sizes,
@@ -461,7 +461,7 @@ class TwoStreamCABlock(nn.Module):
         sync_vec: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 
-        (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate, 
+        (audio_mod1_shift, audio_mod1_scale, audio_mod1_gate,
          audio_mod2_shift, audio_mod2_scale, audio_mod2_gate,
          audio_mod3_shift, audio_mod3_scale, audio_mod3_gate,
         ) = self.audio_mod(sync_vec if sync_vec is not None else vec).chunk(9, dim=-1)
@@ -477,19 +477,19 @@ class TwoStreamCABlock(nn.Module):
             v_cond_mod3_scale,
             v_cond_mod3_gate,
         ) = self.v_cond_mod(vec).chunk(9, dim=-1)
-        
+
         audio_q, audio_k, audio_v = prepare_self_attn_qkv(
-            audio, self.audio_norm1, self.audio_self_attn_qkv, 
+            audio, self.audio_norm1, self.audio_self_attn_qkv,
            self.audio_self_q_norm, self.audio_self_k_norm,
             audio_mod1_shift, audio_mod1_scale, self.num_heads
         )
 
         v_cond_q, v_cond_k, v_cond_v = prepare_self_attn_qkv(
-            v_cond, self.v_cond_norm1, self.v_cond_attn_qkv, 
+            v_cond, self.v_cond_norm1, self.v_cond_attn_qkv,
             self.v_cond_attn_q_norm, self.v_cond_attn_k_norm,
             v_cond_mod1_shift, v_cond_mod1_scale, self.num_heads
         )
-        
+
         # Apply RoPE if needed for audio and visual
         if freqs_cis is not None:
             if not self.interleaved_audio_visual_rope:
@@ -515,18 +515,18 @@ class TwoStreamCABlock(nn.Module):
         if v_freqs_cis is not None and not self.interleaved_audio_visual_rope:
             v_cond_qq, v_cond_kk = apply_rotary_emb(v_cond_q, v_cond_k, v_freqs_cis, head_first=False)
             v_cond_q, v_cond_k = v_cond_qq, v_cond_kk
-        
+
         q = torch.cat((v_cond_q, audio_q), dim=1)
         k = torch.cat((v_cond_k, audio_k), dim=1)
         v = torch.cat((v_cond_v, audio_v), dim=1)
-        
+
         # TODO: look further into here
         if attention.__name__ == "attention_pytorch":
             q, k, v = [t.transpose(1, 2) for t in (q, k, v)]
-        
+
         attn = attention(q, k, v, heads = self.num_heads, mask=attn_mask, skip_reshape=True)
         v_cond_attn, audio_attn = torch.split(attn, [v_cond.shape[1], audio.shape[1]], dim=1)
-        
+
         audio = audio + apply_gate(self.audio_self_proj(audio_attn), gate=audio_mod1_gate)
         v_cond = v_cond + apply_gate(self.v_cond_self_proj(v_cond_attn), gate=v_cond_mod1_gate)
         head_dim = self.hidden_size // self.num_heads
@@ -544,12 +544,12 @@ class TwoStreamCABlock(nn.Module):
         text_k = self.text_cross_k_norm(text_k).to(text_v)
 
         text_len = text_k.shape[1]
-        
-        text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim, 
+
+        text_freqs_cos, text_freqs_sin = self.build_rope_for_text(text_len, head_dim,
                                                                   rope_dim_list=self.rope_dim_list)
         text_freqs_cis = (text_freqs_cos.to(text_k.device), text_freqs_sin.to(text_k.device))
         text_k = apply_rotary_emb(text_k, text_k, text_freqs_cis, head_first=False)[1]
-        
+
         v_cond_audio_q = torch.cat([v_cond_q, audio_q], dim=1)
 
         if attention.__name__ == "attention_pytorch":
@@ -557,7 +557,7 @@ class TwoStreamCABlock(nn.Module):
         cross_attn = attention(v_cond_audio_q, text_k, text_v, self.num_heads, skip_reshape = True)
 
         v_cond_cross_attn, audio_cross_attn = torch.split(cross_attn, [v_cond.shape[1], audio.shape[1]], dim=1)
-        
+
         audio = audio + apply_gate(self.audio_cross_proj(audio_cross_attn), gate=audio_mod2_gate)
         v_cond = v_cond + apply_gate(self.v_cond_cross_proj(v_cond_cross_attn), gate=v_cond_mod2_gate)
 
@@ -565,7 +565,7 @@ class TwoStreamCABlock(nn.Module):
         v_cond = apply_modulated_block(v_cond, self.v_cond_norm3, v_cond_mod3_shift, v_cond_mod3_scale, self.v_cond_mlp, v_cond_mod3_gate)
 
         return audio, cond, v_cond
-    
+
     def prepare_modulated_query(self, x, norm_layer, q_layer, q_norm_layer, shift, scale, num_heads, rope_dim_list):
         x_mod = modulate(norm_layer(x), shift=shift, scale=scale)
 
@@ -577,9 +577,9 @@ class TwoStreamCABlock(nn.Module):
         head_dim = q.shape[-1]
         freqs_cos, freqs_sin = self.build_rope_for_text(q.shape[1], head_dim, rope_dim_list)
         freqs_cis = (freqs_cos.to(q.device), freqs_sin.to(q.device))
-        
+
         q = apply_rotary_emb(q, q, freqs_cis, head_first=False)[0]
-        
+
         return q
 
 class SingleStreamBlock(nn.Module):
@@ -697,7 +697,7 @@ class HunyuanVideoFoley(nn.Module):
         self.patch_size = model_args.get("patch_size", 1)
         self.visual_in_channels = model_args.get("clip_dim", 768)
         self.audio_vae_latent_dim = model_args.get("audio_vae_latent_dim", 128)
-        self.out_channels = self.audio_vae_latent_dim 
+        self.out_channels = self.audio_vae_latent_dim
         self.unpatchify_channels = self.out_channels
 
         self.num_heads = model_args.get("num_heads", 12)
@@ -873,7 +873,7 @@ class HunyuanVideoFoley(nn.Module):
 
             uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
             uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
-            
+
             uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
 
             diff = cond_pos.shape[1] - cond_neg.shape[1]
@@ -885,6 +885,8 @@ class HunyuanVideoFoley(nn.Module):
             clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
             clip_feat = clip_feat.view(2, -1, 768)
 
+            self.conditions = (clip_feat, sync_feat, cond)
+
         else:
             clip_feat, sync_feat, cond = self.conditions
 
@@ -944,7 +946,7 @@ class HunyuanVideoFoley(nn.Module):
             else:
                 audio, cond, v_cond = block(*triple_block_args)
 
-        x = audio 
+        x = audio
         if sync_vec is not None:
             vec = vec.unsqueeze(1).repeat(1, cond_seq_len + v_cond_seq_len, 1)
             vec = torch.cat((vec, sync_vec), dim=1)
diff --git a/comfy/ldm/hunyuan_foley/syncformer.py b/comfy/ldm/hunyuan_foley/syncformer.py
index 8340bfedf..27bc7ee04 100644
--- a/comfy/ldm/hunyuan_foley/syncformer.py
+++ b/comfy/ldm/hunyuan_foley/syncformer.py
@@ -160,7 +160,7 @@ class MotionFormer(nn.Module):
     def __init__(self, device = None, dtype = None, operations = None):
         super().__init__()
         self.APPROX_ATTN_TYPE = "none"
-        self.APPROX_ATTN_DIM = 64 
+        self.APPROX_ATTN_DIM = 64
         self.img_size = 224
         self.patch_size = 16
         self.in_chans = 3
@@ -224,7 +224,7 @@ class MotionFormer(nn.Module):
 
         self.norm = norm_layer(self.embed_dim)
         self.pre_logits = nn.Identity()
-        
+
         transf_enc_layer_kwargs = dict(
             d_model=self.embed_dim,
             nhead=self.num_heads,
@@ -273,7 +273,7 @@ class MotionFormer(nn.Module):
             )
 
         return x, tok_mask
-    
+
     def forward(self, x):
 
         B, S, C, T, H, W = x.shape
@@ -322,7 +322,7 @@ class BaseEncoderLayer(TransformerEncoderComfyv):
         device = None, dtype = None, operations = None,
         *args,
         **kwargs
-    ): 
+    ):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__(operations = operations, *args, **kwargs, **factory_kwargs)
 
@@ -382,7 +382,7 @@ class SpatialTransformerEncoderLayer(BaseEncoderLayer):
         x = rearrange(x, "(BS t) D -> BS t D", BS=BS, t=t)
 
         return x
-    
+
 class AST(torch.nn.Module):
     def __init__(
         self,
@@ -391,7 +391,7 @@ class AST(torch.nn.Module):
         max_segments: int = None,
         device = None, dtype = None, operations = None
     ) -> None:
-        
+
         super().__init__()
         factory_kwargs = {"device": device, "dtype": dtype}
         self.extract_features = True
@@ -518,7 +518,7 @@ class FrequencyTransformerEncoderLayer(BaseEncoderLayer):
         x = x.view(BS, t, D)
 
         return x
-    
+
 class ASTEmbeddings(nn.Module):
 
     def __init__(self, config: ASTConfig, device = None, dtype = None, operations = None) -> None:
@@ -789,7 +789,7 @@ class ASTModel(nn.Module):
             ),
             tok_mask,
         )
-    
+
 class ASTMLPHead(nn.Module):
     def __init__(self, config: ASTConfig, device, dtype, operations):
         super().__init__()
@@ -957,6 +957,7 @@ class Synchformer(nn.Module):
         )
 
     def forward(self, vis):
+        vis = vis.to(next(self.parameters()).dtype)
         vis = vis.permute(0, 1, 3, 2, 4, 5)  # (B, S, C, Tv, H, W)
         vis = self.vfeat_extractor(vis)
         return vis
diff --git a/comfy/ldm/hunyuan_foley/vae.py b/comfy/ldm/hunyuan_foley/vae.py
index 68635bec6..7c634bce0 100644
--- a/comfy/ldm/hunyuan_foley/vae.py
+++ b/comfy/ldm/hunyuan_foley/vae.py
@@ -221,10 +221,11 @@ class FoleyVae(torch.nn.Module):
     def encode(self, x):
         x = x.to(next(self.parameters()).device)
         return self.synchformer(x)
-    
+
     def video_encoding(self, video, step):
+        video = video.to(torch.uint8)
         video = torch.stack([self.syncformer_preprocess(t) for t in video])
-        
+
         t, c, h, w = video.shape
         seg_len = 16
         t = video.size(0)
@@ -233,12 +234,13 @@ class FoleyVae(torch.nn.Module):
         video = video.contiguous()
         stride_t, stride_c, stride_h, stride_w = video.stride()
 
-        # no copies 
+        # no copies
         data = video.as_strided(
            size=(nseg, seg_len, c, h, w),
             stride=(stride_t * step, stride_t, stride_c, stride_h, stride_w),
-        )
+        ).contiguous()
         data = data.unsqueeze(0)  # b
         data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")
+        data = data.float()
 
         return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=1)
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 939c63571..d3dae54ca 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -1121,7 +1121,7 @@ class MultiheadAttentionComfyv(nn.Module):
         self.batch_first = batch_first
         self.head_dim = embed_dim // num_heads
         self.embed_dim = embed_dim
-        
+
    # overwriting state dict loading to convert in_proj_weight/bias -> self._q_proj/_k_proj/_v_proj
     def _load_from_state_dict(
         self,
@@ -1164,26 +1164,17 @@ class MultiheadAttentionComfyv(nn.Module):
             error_msgs,
         )
 
-    def forward(self, src, k = None, v = None, attn_mask = None, key_padding_mask = None):
+    def forward(self, src, k=None, v=None, attn_mask=None, key_padding_mask=None):
+        self._q_proj, self._k_proj, self._v_proj = [
+            t.to(src.device).to(src.dtype)
+            for t in (self._q_proj, self._k_proj, self._v_proj)
+        ]
 
-        self._q_proj, self._k_proj, self._v_proj = [t.to(src.device).to(src.dtype) for t in (self._q_proj, self._k_proj, self._v_proj)]
         q = self._q_proj(src)
-        if k is None:
-            k = self._k_proj(src)
-        if v is None:
-            v = self._v_proj(src)
-        k, v = k.to(src.device).to(src.dtype), v.to(src.device).to(src.dtype)
+        k = self._k_proj(src if k is None else k.to(src.device).to(src.dtype))
+        v = self._v_proj(src if v is None else v.to(src.device).to(src.dtype))
 
-        if k is v:
-            if q is k:
-                q = k = v = q.transpose(1, 0)
-            else:
-                q, k = (x.transpose(1, 0) for x in (q, k))
-                v = k
-        else:
-            q, k, v = (x.transpose(1, 0) for x in (q, k, v))
-
-        output = optimized_attention(q, k, v, self.num_heads, mask = attn_mask)
+        output = optimized_attention(q, k, v, self.num_heads, mask=attn_mask)
         return self.out_proj(output)
 
 # comfyui implementation of nn.TransformerEncoderLayer
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 4f58de83c..1523d2efa 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1413,7 +1413,7 @@ class ACEStep(BaseModel):
         out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
         out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
         return out
-    
+
 class HunyuanFoley(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.hunyuan_foley.model.HunyuanVideoFoley):
         super().__init__(model_config, model_type, device, unet_model)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index d29534fb0..b42d8d47f 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -420,7 +420,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
 
         return dit_config
-    
+
     if '{}triple_blocks.17.audio_cross_q.weight'.format(key_prefix) in state_dict_keys:  # Hunyuan Foley
         dit_config = {}
         dit_config["image_model"] = "hunyuan_foley"
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 1691731a8..ec4b37f06 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1303,7 +1303,7 @@ class Omnigen2(supported_models_base.BASE):
         pref = self.text_encoder_key_prefix[0]
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
-    
+
 class HunyuanFoley(supported_models_base.BASE):
     unet_config = {
         "image_model": "hunyuan_foley",
@@ -1318,7 +1318,7 @@ class HunyuanFoley(supported_models_base.BASE):
         return model_base.HunyuanFoley(self, device=device)
     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(comfy.text_encoders.clap_model.ClapLargeTokenizer, comfy.text_encoders.clap_model.ClapTextEncoderModel)
-    
+
     def process_clip_state_dict(self, state_dict):
         state_dict = utils.state_dict_prefix_replace(state_dict, {k: "transformer." for k in self.text_encoder_key_prefix}, filter_keys=True)
         state_dict["logit_scale"] = torch.tensor(1.0)
diff --git a/comfy_extras/nodes_hunyuan_foley.py b/comfy_extras/nodes_hunyuan_foley.py
index 9057eaa4b..af914d9bf 100644
--- a/comfy_extras/nodes_hunyuan_foley.py
+++ b/comfy_extras/nodes_hunyuan_foley.py
@@ -92,7 +92,7 @@ class HunyuanFoleyConditioning(io.ComfyNode):
 
     @classmethod
     def execute(cls, siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative):
-        
+
         text_encoding_positive = text_encoding_positive[0][0]
         text_encoding_negative = text_encoding_negative[0][0]
         all_ = (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)
@@ -108,7 +108,7 @@ class HunyuanFoleyConditioning(io.ComfyNode):
 
            # temporary repeat values on the cpu
            factor_pos, remainder = divmod(max_value, input.shape[dim])
-            positions = [1] * input.ndim 
+            positions = [1] * input.ndim
            positions[dim] = factor_pos
            input = input.cpu().repeat(*positions)
@@ -120,7 +120,7 @@ class HunyuanFoleyConditioning(io.ComfyNode):
                input = torch.cat([input, pad], dim = dim)
 
            return input
-        
+
         siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
         siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim = 2) for t in (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)]
 
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 3702cb659..509558537 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -48,7 +48,7 @@ class EncodeVideo(io.ComfyNode):
                 io.Conditioning.Output(display_name="encoded_video"),
             ],
         )
-    
+
     @classmethod
     def execute(cls, video, processing_batch_size, step_size, vae = None, clip_vision = None):
 
@@ -94,13 +94,15 @@ class EncodeVideo(io.ComfyNode):
                 chunk = chunk.to(model_dtype)
                 if hasattr(vae, "encode"):
                     try:
+                        if chunk.ndim > 5:
+                            raise ValueError("chunk.ndim > 5")
                         chunk = chunk.movedim(1, -1)
                         out = vae.encode(chunk)
-                    except:
+                    except Exception:
                        out = model.encode(chunk)
                 else:
                     chunk = chunk.movedim(1, -1)
-                    out = vae.encode_image(chunk, crop=False, resize_mode="bilinear")
+                    out = vae.encode_image(chunk.to(torch.uint8), crop=False, resize_mode="bilinear")
                     out = out["image_embeds"]
 
                 out_cpu = out.cpu()
@@ -133,14 +135,14 @@ class ResampleVideo(io.ComfyNode):
         )
 
     @classmethod
     def execute(cls, video, target_fps: int):
-        # doesn't support upsampling 
+        # doesn't support upsampling
        with av.open(video.get_stream_source(), mode="r") as container:
            stream = container.streams.video[0]
            frames = []
 
            src_rate = stream.average_rate or stream.guessed_rate
            src_fps = float(src_rate) if src_rate else None
-            
+
            if src_fps is None:
                logging.warning("src_fps for video resampling is None.")