This commit is contained in:
Valeriy Pavlovich 2026-03-29 22:27:36 -04:00 committed by GitHub
commit 4a9879052e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,11 +8,14 @@ import av
import io
import itertools
import json
import logging
import numpy as np
import math
import torch
from .._util import VideoContainer, VideoCodec, VideoComponents
logger = logging.getLogger(__name__)
def container_to_output_format(container_format: str | None) -> str | None:
"""
@ -402,6 +405,16 @@ class VideoFromComponents(VideoInput):
metadata: Optional[dict] = None,
):
"""Save the video to a file path or BytesIO buffer."""
def mux_packets(container: av.OutputContainer, packets):
if packets is None:
return
if isinstance(packets, (list, tuple)):
for packet in packets:
if packet is not None:
container.mux(packet)
return
container.mux(packets)
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@ -433,6 +446,8 @@ class VideoFromComponents(VideoInput):
audio_sample_rate = int(self.__components.audio['sample_rate'])
waveform = self.__components.audio['waveform']
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
# Guard ffmpeg encoder against invalid upstream audio (NaN/Inf/out-of-range).
waveform = torch.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0).clamp(-1.0, 1.0)
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
@ -449,13 +464,26 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
encoded_audio_packets = None
flush_audio_packets = None
try:
audio_np = waveform.float().cpu().contiguous().numpy()
if not np.isfinite(audio_np).all():
audio_np = np.nan_to_num(audio_np, nan=0.0, posinf=0.0, neginf=0.0)
# Flush encoder
output.mux(audio_stream.encode(None))
frame = av.AudioFrame.from_ndarray(audio_np, format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
encoded_audio_packets = audio_stream.encode(frame)
flush_audio_packets = audio_stream.encode(None)
except (av.error.ArgumentError, ValueError, TypeError) as exc:
logger.error(
"Audio encode failed due to invalid audio data; skipping audio track and saving video-only output: %s",
exc,
)
else:
mux_packets(output, encoded_audio_packets)
mux_packets(output, flush_audio_packets)
def as_trimmed(
self,