allowed returning frames

This commit is contained in:
Yousef Rafat 2025-09-27 23:47:30 +03:00
parent f85e1cf1b9
commit 8311b156ad
2 changed files with 163 additions and 5 deletions

View File

@ -1,15 +1,166 @@
import math
import torch import torch
import numpy as np import numpy as np
from typing import List from typing import List
import torch.nn as nn
from einops import rearrange from einops import rearrange
from torchvision.transforms import v2 from torchvision.transforms import v2
from torch.nn.utils.parametrizations import weight_norm
from comfy.ldm.hunyuan_foley.syncformer import Synchformer from comfy.ldm.hunyuan_foley.syncformer import Synchformer
from comfy.ldm.higgsv2.tokenizer import DACEncoder, DACDecoder
import comfy.ops import comfy.ops
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
# until the higgsv2 pr gets accepted
def WNConv1d(*args, device = None, dtype = None, operations = None, **kwargs):
return weight_norm(operations.Conv1d(*args, **kwargs, device = device, dtype = dtype))
def WNConvTranspose1d(*args, device = None, dtype = None, operations = None, **kwargs):
return weight_norm(operations.ConvTranspose1d(*args, **kwargs, device = device, dtype = dtype))
@torch.jit.script
def snake(x, alpha):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x), with per-channel alpha.

    Flattens all dims after the channel axis so alpha of shape (1, C, 1)
    broadcasts, then restores the original shape. 1e-9 guards against a
    zero alpha in the reciprocal.
    """
    orig_shape = x.shape
    flat = x.reshape(orig_shape[0], orig_shape[1], -1)
    flat = flat + torch.sin(alpha * flat).pow(2) * (alpha + 1e-9).reciprocal()
    return flat.reshape(orig_shape)
class Snake1d(nn.Module):
    """Channel-wise Snake activation with a learnable per-channel alpha (init 1)."""

    def __init__(self, channels, device = None, dtype = None):
        super().__init__()
        init = torch.ones(1, channels, 1, device=device, dtype=dtype)
        self.alpha = nn.Parameter(init)

    def forward(self, x):
        return snake(x, self.alpha)
class DACResidualUnit(nn.Module):
    """Residual unit of two Snake-activated convolutions (dilated 7-tap, then 1x1).

    The skip path is center-cropped to match the conv output length when the
    dilated conv shortens the sequence.
    """

    def __init__(self, dim: int = 16, dilation: int = 1, device = None, dtype = None, operations = None):
        super().__init__()
        padding = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim, device=device, dtype=dtype),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=padding,
                     device=device, dtype=dtype, operations=operations),
            Snake1d(dim, device=device, dtype=dtype),
            WNConv1d(dim, dim, kernel_size=1, device=device, dtype=dtype, operations=operations),
        )

    def forward(self, x):
        out = self.block(x)
        crop = (x.shape[-1] - out.shape[-1]) // 2
        residual = x[..., crop:-crop] if crop > 0 else x
        return residual + out
class DACEncoderBlock(nn.Module):
    """Encoder stage: three dilated residual units followed by a strided
    downsampling convolution that doubles the channel count.

    Args:
        dim: output channel count; the residual units operate on dim // 2.
        stride: temporal downsampling factor of the final convolution.
    """

    def __init__(self, dim: int = 16, stride: int = 1, device = None, dtype = None, operations = None):
        super().__init__()
        self.block = nn.Sequential(
            DACResidualUnit(dim // 2, dilation=1, device=device, dtype=dtype, operations=operations),
            DACResidualUnit(dim // 2, dilation=3, device=device, dtype=dtype, operations=operations),
            DACResidualUnit(dim // 2, dilation=9, device=device, dtype=dtype, operations=operations),
            # bugfix: Snake1d previously dropped device/dtype, creating its alpha
            # parameter on the default device — inconsistent with every other layer here.
            Snake1d(dim // 2, device=device, dtype=dtype),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                device=device, dtype=dtype, operations=operations,
            ),
        )

    def forward(self, x):
        return self.block(x)
class DACEncoder(nn.Module):
    """DAC encoder: a stem convolution, a stack of downsampling encoder blocks
    (channels double at each stride), then a projection to the latent size.

    Args:
        d_model: stem channel width; doubled once per entry in `strides`.
        strides: per-stage temporal downsampling factors.
        d_latent: channel count of the output latent.

    Attributes:
        enc_dim: channel width before the latent projection
            (d_model * 2 ** len(strides)).
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 256,
        device = None, dtype = None, operations = None
    ):
        super().__init__()
        # Stem convolution on the mono waveform (1 input channel).
        layers = [WNConv1d(1, d_model, kernel_size=7, padding=3, device=device, dtype=dtype, operations=operations)]
        # Encoder blocks double channels as they downsample by `stride`.
        for stride in strides:
            d_model *= 2
            layers.append(DACEncoderBlock(d_model, stride=stride, device=device, dtype=dtype, operations=operations))
        layers += [
            # bugfix: pass device/dtype (previously dropped for this Snake1d only)
            Snake1d(d_model, device=device, dtype=dtype),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1, device=device, dtype=dtype, operations=operations),
        ]
        # Wrap the layer list into nn.Sequential (typo "black" fixed).
        self.block = nn.Sequential(*layers)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
class DACDecoderBlock(nn.Module):
    """Decoder stage: Snake activation, a transposed convolution that upsamples
    by `stride`, then three residual units with increasing dilation."""

    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, device = None, dtype = None, operations = None):
        super().__init__()
        upsample = WNConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
            # odd strides need one extra output sample so lengths line up
            output_padding=stride % 2,
            device=device, dtype=dtype, operations=operations,
        )
        stages = [Snake1d(input_dim, device=device, dtype=dtype), upsample]
        stages += [
            DACResidualUnit(output_dim, dilation=d, device=device, dtype=dtype, operations=operations)
            for d in (1, 3, 9)
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return self.block(x)
class DACDecoder(nn.Module):
    """DAC decoder: stem convolution, upsampling decoder blocks (channels halve
    at each rate), then a Snake activation and a projection to `d_out` channels.

    Args:
        input_channel: channel count of the incoming latent.
        channels: channel width after the stem convolution.
        rates: per-stage temporal upsampling factors.
        d_out: output channel count (1 for a mono waveform).
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        device = None, dtype = None, operations = None
    ):
        super().__init__()
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3, device=device, dtype=dtype, operations=operations)]
        # bugfix: seed output_dim so an empty `rates` no longer raises
        # NameError when building the head below.
        output_dim = channels
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            layers.append(DACDecoderBlock(input_dim, output_dim, stride, device=device, dtype=dtype, operations=operations))
        layers += [
            Snake1d(output_dim, device=device, dtype=dtype),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3, device=device, dtype=dtype, operations=operations),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class DAC(torch.nn.Module): class DAC(torch.nn.Module):
def __init__( def __init__(
self, self,

View File

@ -89,7 +89,7 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'") raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float: def get_duration(self, return_frames=False) -> float:
""" """
Returns the duration of the video in seconds. Returns the duration of the video in seconds.
@ -100,14 +100,18 @@ class VideoFromFile(VideoInput):
self.__file.seek(0) self.__file.seek(0)
with av.open(self.__file, mode="r") as container: with av.open(self.__file, mode="r") as container:
if container.duration is not None: if container.duration is not None:
return float(container.duration / av.time_base) if not return_frames:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate # Fallback: calculate from frame count and frame rate
video_stream = next( video_stream = next(
(s for s in container.streams if s.type == "video"), None (s for s in container.streams if s.type == "video"), None
) )
if video_stream and video_stream.frames and video_stream.average_rate: if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate) length = float(video_stream.frames / video_stream.average_rate)
if return_frames:
return length, float(video_stream.frames)
return length
# Last resort: decode frames to count them # Last resort: decode frames to count them
if video_stream and video_stream.average_rate: if video_stream and video_stream.average_rate:
@ -117,7 +121,10 @@ class VideoFromFile(VideoInput):
for _ in packet.decode(): for _ in packet.decode():
frame_count += 1 frame_count += 1
if frame_count > 0: if frame_count > 0:
return float(frame_count / video_stream.average_rate) length = float(frame_count / video_stream.average_rate)
if return_frames:
return length, float(frame_count)
return length
raise ValueError(f"Could not determine duration for file '{self.__file}'") raise ValueError(f"Could not determine duration for file '{self.__file}'")