Fix video save audio encode failures for invalid waveform values

This commit is contained in:
vp 2026-03-27 16:35:56 +03:00
parent 225c52f6a4
commit 462592a359

View File

@ -8,11 +8,14 @@ import av
import io
import itertools
import json
import logging
import numpy as np
import math
import torch
from .._util import VideoContainer, VideoCodec, VideoComponents
logger = logging.getLogger(__name__)
def container_to_output_format(container_format: str | None) -> str | None:
"""
@ -402,6 +405,16 @@ class VideoFromComponents(VideoInput):
metadata: Optional[dict] = None,
):
"""Save the video to a file path or BytesIO buffer."""
def mux_packets(container: av.OutputContainer, packets):
if packets is None:
return
if isinstance(packets, (list, tuple)):
for packet in packets:
if packet is not None:
container.mux(packet)
return
container.mux(packets)
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@ -433,6 +446,8 @@ class VideoFromComponents(VideoInput):
audio_sample_rate = int(self.__components.audio['sample_rate'])
waveform = self.__components.audio['waveform']
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
# Guard ffmpeg encoder against invalid upstream audio (NaN/Inf/out-of-range).
waveform = torch.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0).clamp(-1.0, 1.0)
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
@ -449,13 +464,22 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
try:
audio_np = waveform.float().cpu().contiguous().numpy()
if not np.isfinite(audio_np).all():
audio_np = np.nan_to_num(audio_np, nan=0.0, posinf=0.0, neginf=0.0)
# Flush encoder
output.mux(audio_stream.encode(None))
frame = av.AudioFrame.from_ndarray(audio_np, format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
mux_packets(output, audio_stream.encode(frame))
# Flush encoder
mux_packets(output, audio_stream.encode(None))
except Exception as exc:
logger.warning(
"Failed to encode audio track, saving video-only output: %s", exc
)
def as_trimmed(
self,