mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 18:27:40 +08:00
Allow returning frame counts from get_duration
This commit is contained in:
parent
f85e1cf1b9
commit
8311b156ad
@ -1,15 +1,166 @@
|
||||
import math
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import List
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from torchvision.transforms import v2
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
|
||||
from comfy.ldm.hunyuan_foley.syncformer import Synchformer
|
||||
from comfy.ldm.higgsv2.tokenizer import DACEncoder, DACDecoder
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
# until the higgsv2 pr gets accepted
|
||||
def WNConv1d(*args, device = None, dtype = None, operations = None, **kwargs):
|
||||
return weight_norm(operations.Conv1d(*args, **kwargs, device = device, dtype = dtype))
|
||||
|
||||
|
||||
def WNConvTranspose1d(*args, device = None, dtype = None, operations = None, **kwargs):
|
||||
return weight_norm(operations.ConvTranspose1d(*args, **kwargs, device = device, dtype = dtype))
|
||||
|
||||
|
||||
@torch.jit.script
def snake(x, alpha):
    """Snake activation: x + (1/alpha) * sin^2(alpha * x), per channel.

    Trailing dimensions are flattened so the (1, C, 1) alpha broadcasts over
    any input rank, then the caller's original shape is restored. The 1e-9
    term guards the reciprocal against alpha == 0.
    """
    orig_shape = x.shape
    flat = x.reshape(orig_shape[0], orig_shape[1], -1)
    flat = flat + (alpha + 1e-9).reciprocal() * torch.sin(alpha * flat).pow(2)
    return flat.reshape(orig_shape)
|
||||
|
||||
|
||||
class Snake1d(nn.Module):
    """Channel-wise learnable Snake activation for 1-D feature maps."""

    def __init__(self, channels, device = None, dtype = None):
        super().__init__()
        # One learnable alpha per channel, broadcast over batch and time.
        alpha_init = torch.ones(1, channels, 1, device = device, dtype = dtype)
        self.alpha = nn.Parameter(alpha_init)

    def forward(self, x):
        return snake(x, self.alpha)
|
||||
|
||||
class DACResidualUnit(nn.Module):
    """Residual unit: Snake -> dilated 7-tap conv -> Snake -> 1x1 conv, with a
    center-cropped skip connection when the conv path shrinks the time axis."""

    def __init__(self, dim: int = 16, dilation: int = 1, device = None, dtype = None, operations = None):
        super().__init__()
        # "Same"-style padding for the dilated kernel-7 convolution.
        padding = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim, device = device, dtype = dtype),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=padding, device = device, dtype = dtype, operations = operations),
            Snake1d(dim, device = device, dtype = dtype),
            WNConv1d(dim, dim, kernel_size=1, device = device, dtype = dtype, operations = operations),
        )

    def forward(self, x):
        out = self.block(x)
        # If the conv path lost samples, crop the skip path symmetrically.
        crop = (x.shape[-1] - out.shape[-1]) // 2
        if crop > 0:
            x = x[..., crop:-crop]
        return x + out
|
||||
|
||||
|
||||
class DACEncoderBlock(nn.Module):
    """Downsampling encoder block: three dilated residual units, then Snake and
    a strided convolution that reduces time by `stride` while going from
    dim // 2 to dim channels."""

    def __init__(self, dim: int = 16, stride: int = 1, device = None, dtype = None, operations = None):
        super().__init__()
        self.block = nn.Sequential(
            DACResidualUnit(dim // 2, dilation=1, device = device, dtype = dtype, operations = operations),
            DACResidualUnit(dim // 2, dilation=3, device = device, dtype = dtype, operations = operations),
            DACResidualUnit(dim // 2, dilation=9, device = device, dtype = dtype, operations = operations),
            # Fix: pass device/dtype here too, consistent with every other
            # Snake1d instantiation, so alpha is created on the requested
            # device/dtype instead of the defaults.
            Snake1d(dim // 2, device = device, dtype = dtype),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                device = device, dtype = dtype, operations = operations
            ),
        )

    def forward(self, x):
        return self.block(x)
|
||||
|
||||
|
||||
class DACEncoder(nn.Module):
    """DAC waveform encoder.

    Structure: initial 7-tap conv (1 -> d_model channels), one DACEncoderBlock
    per entry in `strides` (each doubles channels and downsamples by the
    stride), then Snake + a final conv projecting to `d_latent` channels.
    """

    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 256,
        device = None, dtype = None, operations = None
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [DACEncoderBlock(d_model, stride=stride, device = device, dtype = dtype, operations = operations)]

        # Create last convolution
        self.block += [
            # Fix: pass device/dtype, consistent with the other Snake1d
            # instantiations, so alpha is created on the requested device/dtype.
            Snake1d(d_model, device = device, dtype = dtype),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1, device = device, dtype = dtype, operations = operations),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
|
||||
|
||||
|
||||
class DACDecoderBlock(nn.Module):
    """Upsampling decoder block: Snake, a transposed conv that upsamples time
    by `stride` (input_dim -> output_dim channels), then three dilated
    residual units."""

    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, device = None, dtype = None, operations = None):
        super().__init__()
        upsample = WNConvTranspose1d(
            input_dim,
            output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
            # Odd strides need one extra sample to hit exactly stride * T.
            output_padding=stride % 2,
            device = device, dtype = dtype, operations = operations
        )
        residuals = [
            DACResidualUnit(output_dim, dilation=d, device = device, dtype = dtype, operations = operations)
            for d in (1, 3, 9)
        ]
        self.block = nn.Sequential(
            Snake1d(input_dim, device = device, dtype = dtype),
            upsample,
            *residuals,
        )

    def forward(self, x):
        return self.block(x)
|
||||
|
||||
|
||||
class DACDecoder(nn.Module):
    """DAC waveform decoder: initial 7-tap conv, one upsampling DACDecoderBlock
    per entry in `rates` (each halves the channel count), then Snake + a final
    conv down to `d_out` channels.

    NOTE(review): assumes `rates` is non-empty — with an empty list the
    post-loop reference to the last block's output width raises NameError;
    verify against callers.
    """

    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
        device = None, dtype = None, operations = None
    ):
        super().__init__()

        modules = [WNConv1d(input_channel, channels, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations )]

        # Each block upsamples time by `stride` and halves the channel count.
        for i, stride in enumerate(rates):
            in_ch = channels // 2**i
            out_ch = channels // 2 ** (i + 1)
            modules.append(DACDecoderBlock(in_ch, out_ch, stride, device = device, dtype = dtype, operations = operations))

        modules.append(Snake1d(out_ch, device = device, dtype = dtype))
        modules.append(WNConv1d(out_ch, d_out, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations))

        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)
|
||||
|
||||
class DAC(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -89,7 +89,7 @@ class VideoFromFile(VideoInput):
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_duration(self) -> float:
|
||||
def get_duration(self, return_frames=False) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
@ -100,14 +100,18 @@ class VideoFromFile(VideoInput):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
if container.duration is not None:
|
||||
return float(container.duration / av.time_base)
|
||||
if not return_frames:
|
||||
return float(container.duration / av.time_base)
|
||||
|
||||
# Fallback: calculate from frame count and frame rate
|
||||
video_stream = next(
|
||||
(s for s in container.streams if s.type == "video"), None
|
||||
)
|
||||
if video_stream and video_stream.frames and video_stream.average_rate:
|
||||
return float(video_stream.frames / video_stream.average_rate)
|
||||
length = float(video_stream.frames / video_stream.average_rate)
|
||||
if return_frames:
|
||||
return length, float(video_stream.frames)
|
||||
return length
|
||||
|
||||
# Last resort: decode frames to count them
|
||||
if video_stream and video_stream.average_rate:
|
||||
@ -117,7 +121,10 @@ class VideoFromFile(VideoInput):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
length = float(frame_count / video_stream.average_rate)
|
||||
if return_frames:
|
||||
return length, float(frame_count)
|
||||
return length
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user