This commit is contained in:
Valeriy Pavlovich 2026-03-28 20:41:53 +03:00 committed by GitHub
commit dc5014f0c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,11 +8,14 @@ import av
import io import io
import itertools import itertools
import json import json
import logging
import numpy as np import numpy as np
import math import math
import torch import torch
from .._util import VideoContainer, VideoCodec, VideoComponents from .._util import VideoContainer, VideoCodec, VideoComponents
logger = logging.getLogger(__name__)
def container_to_output_format(container_format: str | None) -> str | None: def container_to_output_format(container_format: str | None) -> str | None:
""" """
@ -402,6 +405,16 @@ class VideoFromComponents(VideoInput):
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
): ):
"""Save the video to a file path or BytesIO buffer.""" """Save the video to a file path or BytesIO buffer."""
def mux_packets(container: av.OutputContainer, packets):
if packets is None:
return
if isinstance(packets, (list, tuple)):
for packet in packets:
if packet is not None:
container.mux(packet)
return
container.mux(packets)
if format != VideoContainer.AUTO and format != VideoContainer.MP4: if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now") raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264: if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@ -433,6 +446,8 @@ class VideoFromComponents(VideoInput):
audio_sample_rate = int(self.__components.audio['sample_rate']) audio_sample_rate = int(self.__components.audio['sample_rate'])
waveform = self.__components.audio['waveform'] waveform = self.__components.audio['waveform']
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])] waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
# Guard ffmpeg encoder against invalid upstream audio (NaN/Inf/out-of-range).
waveform = torch.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0).clamp(-1.0, 1.0)
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo') layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout) audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
@ -449,13 +464,26 @@ class VideoFromComponents(VideoInput):
output.mux(packet) output.mux(packet)
if audio_stream and self.__components.audio: if audio_stream and self.__components.audio:
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout) encoded_audio_packets = None
frame.sample_rate = audio_sample_rate flush_audio_packets = None
frame.pts = 0 try:
output.mux(audio_stream.encode(frame)) audio_np = waveform.float().cpu().contiguous().numpy()
if not np.isfinite(audio_np).all():
audio_np = np.nan_to_num(audio_np, nan=0.0, posinf=0.0, neginf=0.0)
# Flush encoder frame = av.AudioFrame.from_ndarray(audio_np, format='fltp', layout=layout)
output.mux(audio_stream.encode(None)) frame.sample_rate = audio_sample_rate
frame.pts = 0
encoded_audio_packets = audio_stream.encode(frame)
flush_audio_packets = audio_stream.encode(None)
except (av.error.ArgumentError, ValueError, TypeError) as exc:
logger.error(
"Audio encode failed due to invalid audio data; skipping audio track and saving video-only output: %s",
exc,
)
else:
mux_packets(output, encoded_audio_packets)
mux_packets(output, flush_audio_packets)
def as_trimmed( def as_trimmed(
self, self,