From 8311b156ad802a7840e636c14f624c1595110fc7 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Sat, 27 Sep 2025 23:47:30 +0300 Subject: [PATCH] allowed returning frames --- comfy/ldm/hunyuan_foley/vae.py | 153 +++++++++++++++++++- comfy_api/latest/_input_impl/video_types.py | 15 +- 2 files changed, 163 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/hunyuan_foley/vae.py b/comfy/ldm/hunyuan_foley/vae.py index 7c4057072..e691f248c 100644 --- a/comfy/ldm/hunyuan_foley/vae.py +++ b/comfy/ldm/hunyuan_foley/vae.py @@ -1,15 +1,166 @@ +import math import torch import numpy as np from typing import List +import torch.nn as nn from einops import rearrange from torchvision.transforms import v2 +from torch.nn.utils.parametrizations import weight_norm from comfy.ldm.hunyuan_foley.syncformer import Synchformer -from comfy.ldm.higgsv2.tokenizer import DACEncoder, DACDecoder import comfy.ops ops = comfy.ops.disable_weight_init +# until the higgsv2 pr gets accepted +def WNConv1d(*args, device = None, dtype = None, operations = None, **kwargs): + return weight_norm(operations.Conv1d(*args, **kwargs, device = device, dtype = dtype)) + + +def WNConvTranspose1d(*args, device = None, dtype = None, operations = None, **kwargs): + return weight_norm(operations.ConvTranspose1d(*args, **kwargs, device = device, dtype = dtype)) + + +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels, device = None, dtype = None): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1, device = device, dtype = dtype)) + + def forward(self, x): + return snake(x, self.alpha) + +class DACResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1, device = None, dtype = None, operations = None): + 
super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim, device = device, dtype = dtype), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad, device = device, dtype = dtype, operations = operations), + Snake1d(dim, device = device, dtype = dtype), + WNConv1d(dim, dim, kernel_size=1, device = device, dtype = dtype, operations = operations), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class DACEncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1, device = None, dtype = None, operations = None): + super().__init__() + self.block = nn.Sequential( + DACResidualUnit(dim // 2, dilation=1, device = device, dtype = dtype, operations = operations), + DACResidualUnit(dim // 2, dilation=3, device = device, dtype = dtype, operations = operations), + DACResidualUnit(dim // 2, dilation=9, device = device, dtype = dtype, operations = operations), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + device = device, dtype = dtype, operations = operations + ), + ) + + def forward(self, x): + return self.block(x) + + +class DACEncoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 256, + device = None, dtype = None, operations = None + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [DACEncoderBlock(d_model, stride=stride, device = device, dtype = dtype, operations = operations)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1, device = 
device, dtype = dtype, operations = operations), + ] + + # Wrap block into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DACDecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, device = None, dtype = None, operations = None): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim, device = device, dtype = dtype), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=stride % 2, + device = device, dtype = dtype, operations = operations + ), + DACResidualUnit(output_dim, dilation=1, device = device, dtype = dtype, operations = operations), + DACResidualUnit(output_dim, dilation=3, device = device, dtype = dtype, operations = operations), + DACResidualUnit(output_dim, dilation=9, device = device, dtype = dtype, operations = operations), + ) + + def forward(self, x): + return self.block(x) + + +class DACDecoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + device = None, dtype = None, operations = None + ): + super().__init__() + + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations )] + + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DACDecoderBlock(input_dim, output_dim, stride, device = device, dtype = dtype, operations = operations)] + + layers += [ + Snake1d(output_dim, device = device, dtype = dtype), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + class DAC(torch.nn.Module): def __init__( self, diff --git a/comfy_api/latest/_input_impl/video_types.py
b/comfy_api/latest/_input_impl/video_types.py index f646504c8..a57f5fd73 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -89,7 +89,7 @@ class VideoFromFile(VideoInput): return stream.width, stream.height raise ValueError(f"No video stream found in file '{self.__file}'") - def get_duration(self) -> float: + def get_duration(self, return_frames=False) -> float: """ Returns the duration of the video in seconds. @@ -100,14 +100,18 @@ class VideoFromFile(VideoInput): self.__file.seek(0) with av.open(self.__file, mode="r") as container: if container.duration is not None: - return float(container.duration / av.time_base) + if not return_frames: + return float(container.duration / av.time_base) # Fallback: calculate from frame count and frame rate video_stream = next( (s for s in container.streams if s.type == "video"), None ) if video_stream and video_stream.frames and video_stream.average_rate: - return float(video_stream.frames / video_stream.average_rate) + length = float(video_stream.frames / video_stream.average_rate) + if return_frames: + return length, float(video_stream.frames) + return length # Last resort: decode frames to count them if video_stream and video_stream.average_rate: @@ -117,7 +121,10 @@ class VideoFromFile(VideoInput): for _ in packet.decode(): frame_count += 1 if frame_count > 0: - return float(frame_count / video_stream.average_rate) + length = float(frame_count / video_stream.average_rate) + if return_frames: + return length, float(frame_count) + return length raise ValueError(f"Could not determine duration for file '{self.__file}'")