Allow returning frame counts from get_duration

This commit is contained in:
Yousef Rafat 2025-09-27 23:47:30 +03:00
parent f85e1cf1b9
commit 8311b156ad
2 changed files with 163 additions and 5 deletions

View File

@ -1,15 +1,166 @@
import math
import torch
import numpy as np
from typing import List
import torch.nn as nn
from einops import rearrange
from torchvision.transforms import v2
from torch.nn.utils.parametrizations import weight_norm
from comfy.ldm.hunyuan_foley.syncformer import Synchformer
from comfy.ldm.higgsv2.tokenizer import DACEncoder, DACDecoder
import comfy.ops
ops = comfy.ops.disable_weight_init
# until the higgsv2 pr gets accepted
def WNConv1d(*args, device = None, dtype = None, operations = None, **kwargs):
return weight_norm(operations.Conv1d(*args, **kwargs, device = device, dtype = dtype))
def WNConvTranspose1d(*args, device = None, dtype = None, operations = None, **kwargs):
return weight_norm(operations.ConvTranspose1d(*args, **kwargs, device = device, dtype = dtype))
@torch.jit.script
def snake(x, alpha):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x), flattening trailing dims first.

    The 1e-9 term guards against division by zero when alpha is zero.
    """
    original_shape = x.shape
    flat = x.reshape(original_shape[0], original_shape[1], -1)
    flat = flat + torch.sin(alpha * flat).pow(2) * (alpha + 1e-9).reciprocal()
    return flat.reshape(original_shape)
class Snake1d(nn.Module):
    """Channel-wise Snake activation with one learnable alpha per channel."""

    def __init__(self, channels, device=None, dtype=None):
        super().__init__()
        # Shape (1, C, 1) so alpha broadcasts over batch and time dimensions.
        initial = torch.ones(1, channels, 1, device=device, dtype=dtype)
        self.alpha = nn.Parameter(initial)

    def forward(self, x):
        return snake(x, self.alpha)
class DACResidualUnit(nn.Module):
    """Residual unit: Snake -> dilated 7-tap conv -> Snake -> 1x1 conv, plus a skip.

    The skip path is center-cropped to match the conv path's time length.
    """

    def __init__(self, dim: int = 16, dilation: int = 1, device=None, dtype=None, operations=None):
        super().__init__()
        # "Same"-style padding for the dilated 7-tap convolution.
        padding = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim, device=device, dtype=dtype),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=padding,
                     device=device, dtype=dtype, operations=operations),
            Snake1d(dim, device=device, dtype=dtype),
            WNConv1d(dim, dim, kernel_size=1, device=device, dtype=dtype, operations=operations),
        )

    def forward(self, x):
        residual = self.block(x)
        # Symmetrically trim the input if the conv path shrank the time axis.
        trim = (x.shape[-1] - residual.shape[-1]) // 2
        if trim > 0:
            x = x[..., trim:-trim]
        return x + residual
class DACEncoderBlock(nn.Module):
    """Encoder block: three dilated residual units at dim//2, then a strided
    downsampling conv that doubles the channel count to `dim`.

    Fix: the Snake1d before the downsampling conv previously omitted
    device/dtype, so its alpha parameter was created on the default
    device/dtype while every sibling module honored the requested placement.
    """

    def __init__(self, dim: int = 16, stride: int = 1, device=None, dtype=None, operations=None):
        super().__init__()
        self.block = nn.Sequential(
            DACResidualUnit(dim // 2, dilation=1, device=device, dtype=dtype, operations=operations),
            DACResidualUnit(dim // 2, dilation=3, device=device, dtype=dtype, operations=operations),
            DACResidualUnit(dim // 2, dilation=9, device=device, dtype=dtype, operations=operations),
            Snake1d(dim // 2, device=device, dtype=dtype),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                device=device, dtype=dtype, operations=operations,
            ),
        )

    def forward(self, x):
        return self.block(x)
class DACEncoder(nn.Module):
    """DAC encoder: input conv, strided blocks that double channels while
    downsampling, then a Snake activation and projection to the latent size.

    Fixes: the final Snake1d previously omitted device/dtype (alpha parameter
    created on the default device/dtype); the mutable list default for
    `strides` is replaced with an equivalent tuple (it is only iterated, so
    behavior is unchanged).

    Attributes:
        enc_dim: channel count after the last downsampling block.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: tuple = (2, 4, 8, 8),
        d_latent: int = 256,
        device=None, dtype=None, operations=None,
    ):
        super().__init__()
        # First convolution maps the mono waveform to d_model channels.
        layers = [WNConv1d(1, d_model, kernel_size=7, padding=3,
                           device=device, dtype=dtype, operations=operations)]
        # Each block doubles channels while downsampling by `stride`.
        for stride in strides:
            d_model *= 2
            layers.append(DACEncoderBlock(d_model, stride=stride,
                                          device=device, dtype=dtype, operations=operations))
        # Final activation and projection to the latent dimension.
        layers.append(Snake1d(d_model, device=device, dtype=dtype))
        layers.append(WNConv1d(d_model, d_latent, kernel_size=3, padding=1,
                               device=device, dtype=dtype, operations=operations))
        # Wrap the layers into nn.Sequential.
        self.block = nn.Sequential(*layers)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
class DACDecoderBlock(nn.Module):
    """Decoder block: Snake, transposed upsampling conv, then three dilated
    residual units at the output width."""

    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1,
                 device=None, dtype=None, operations=None):
        super().__init__()
        upsample = WNConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
            output_padding=stride % 2,
            device=device, dtype=dtype, operations=operations,
        )
        self.block = nn.Sequential(
            Snake1d(input_dim, device=device, dtype=dtype),
            upsample,
            DACResidualUnit(output_dim, dilation=1, device=device, dtype=dtype, operations=operations),
            DACResidualUnit(output_dim, dilation=3, device=device, dtype=dtype, operations=operations),
            DACResidualUnit(output_dim, dilation=9, device=device, dtype=dtype, operations=operations),
        )

    def forward(self, x):
        return self.block(x)
class DACDecoder(nn.Module):
    """DAC decoder: input conv, upsampling blocks that halve channels per
    rate, then a Snake activation and an output conv to `d_out` channels.

    Fix: `output_dim` is initialized before the loop so an empty `rates`
    no longer raises NameError when building the tail layers (previously the
    name was only bound inside the loop body).
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        device=None, dtype=None, operations=None,
    ):
        super().__init__()
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3,
                           device=device, dtype=dtype, operations=operations)]
        # With no upsampling stages the tail operates directly on `channels`.
        output_dim = channels
        # Each stage halves the channel count while upsampling by `stride`.
        for i, stride in enumerate(rates):
            input_dim = channels // 2 ** i
            output_dim = channels // 2 ** (i + 1)
            layers.append(DACDecoderBlock(input_dim, output_dim, stride,
                                          device=device, dtype=dtype, operations=operations))
        layers.append(Snake1d(output_dim, device=device, dtype=dtype))
        layers.append(WNConv1d(output_dim, d_out, kernel_size=7, padding=3,
                               device=device, dtype=dtype, operations=operations))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
class DAC(torch.nn.Module):
def __init__(
self,

View File

@ -89,7 +89,7 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
def get_duration(self, return_frames=False) -> float:
"""
Returns the duration of the video in seconds.
@ -100,14 +100,18 @@ class VideoFromFile(VideoInput):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
if not return_frames:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
length = float(video_stream.frames / video_stream.average_rate)
if return_frames:
return length, float(video_stream.frames)
return length
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
@ -117,7 +121,10 @@ class VideoFromFile(VideoInput):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
length = float(frame_count / video_stream.average_rate)
if return_frames:
return length, float(frame_count)
return length
raise ValueError(f"Could not determine duration for file '{self.__file}'")