mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-01 21:02:30 +08:00
Read audio and video at the same time in video loader node. (#13591)
This commit is contained in:
parent
64b8457f55
commit
3cbf015578
@ -12,6 +12,7 @@ import numpy as np
|
|||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
def container_to_output_format(container_format: str | None) -> str | None:
|
def container_to_output_format(container_format: str | None) -> str | None:
|
||||||
@ -238,32 +239,86 @@ class VideoFromFile(VideoInput):
|
|||||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||||
else:
|
else:
|
||||||
start_time = self.__start_time
|
start_time = self.__start_time
|
||||||
|
|
||||||
# Get video frames
|
# Get video frames
|
||||||
frames = []
|
frames = []
|
||||||
|
audio_frames = []
|
||||||
alphas = None
|
alphas = None
|
||||||
start_pts = int(start_time / video_stream.time_base)
|
start_pts = int(start_time / video_stream.time_base)
|
||||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||||
container.seek(start_pts, stream=video_stream)
|
|
||||||
image_format = 'gbrpf32le'
|
|
||||||
for frame in container.decode(video_stream):
|
|
||||||
if alphas is None:
|
|
||||||
for comp in frame.format.components:
|
|
||||||
if comp.is_alpha:
|
|
||||||
alphas = []
|
|
||||||
image_format = 'gbrapf32le'
|
|
||||||
break
|
|
||||||
|
|
||||||
if frame.pts < start_pts:
|
if start_pts != 0:
|
||||||
continue
|
container.seek(start_pts, stream=video_stream)
|
||||||
if self.__duration and frame.pts >= end_pts:
|
|
||||||
|
image_format = 'gbrpf32le'
|
||||||
|
audio = None
|
||||||
|
|
||||||
|
streams = [video_stream]
|
||||||
|
has_first_audio_frame = False
|
||||||
|
checked_alpha = False
|
||||||
|
|
||||||
|
# Default to False so we decode until EOF if duration is 0
|
||||||
|
video_done = False
|
||||||
|
audio_done = True
|
||||||
|
|
||||||
|
if len(container.streams.audio):
|
||||||
|
audio_stream = container.streams.audio[-1]
|
||||||
|
streams += [audio_stream]
|
||||||
|
resampler = av.audio.resampler.AudioResampler(format='fltp')
|
||||||
|
audio_done = False
|
||||||
|
|
||||||
|
for packet in container.demux(*streams):
|
||||||
|
if video_done and audio_done:
|
||||||
break
|
break
|
||||||
|
|
||||||
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
|
if packet.stream.type == "video":
|
||||||
if alphas is None:
|
if video_done:
|
||||||
frames.append(torch.from_numpy(img))
|
continue
|
||||||
else:
|
try:
|
||||||
frames.append(torch.from_numpy(img[..., :-1]))
|
for frame in packet.decode():
|
||||||
alphas.append(torch.from_numpy(img[..., -1:]))
|
if frame.pts < start_pts:
|
||||||
|
continue
|
||||||
|
if self.__duration and frame.pts >= end_pts:
|
||||||
|
video_done = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not checked_alpha:
|
||||||
|
for comp in frame.format.components:
|
||||||
|
if comp.is_alpha:
|
||||||
|
alphas = []
|
||||||
|
image_format = 'gbrapf32le'
|
||||||
|
break
|
||||||
|
checked_alpha = True
|
||||||
|
|
||||||
|
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
|
||||||
|
if alphas is None:
|
||||||
|
frames.append(torch.from_numpy(img))
|
||||||
|
else:
|
||||||
|
frames.append(torch.from_numpy(img[..., :-1]))
|
||||||
|
alphas.append(torch.from_numpy(img[..., -1:]))
|
||||||
|
except av.error.InvalidDataError:
|
||||||
|
logging.info("pyav decode error")
|
||||||
|
|
||||||
|
elif packet.stream.type == "audio":
|
||||||
|
if audio_done:
|
||||||
|
continue
|
||||||
|
|
||||||
|
aframes = itertools.chain.from_iterable(
|
||||||
|
map(resampler.resample, packet.decode())
|
||||||
|
)
|
||||||
|
for frame in aframes:
|
||||||
|
if self.__duration and frame.time > start_time + self.__duration:
|
||||||
|
audio_done = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not has_first_audio_frame:
|
||||||
|
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||||
|
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
|
||||||
|
if to_skip < frame.samples:
|
||||||
|
has_first_audio_frame = True
|
||||||
|
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||||
|
else:
|
||||||
|
audio_frames.append(frame.to_ndarray())
|
||||||
|
|
||||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
|
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
|
||||||
if alphas is not None:
|
if alphas is not None:
|
||||||
@ -272,42 +327,16 @@ class VideoFromFile(VideoInput):
|
|||||||
# Get frame rate
|
# Get frame rate
|
||||||
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
|
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
|
||||||
|
|
||||||
# Get audio if available
|
if len(audio_frames) > 0:
|
||||||
audio = None
|
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||||
container.seek(start_pts, stream=video_stream)
|
if self.__duration:
|
||||||
# Use last stream for consistency
|
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
|
||||||
if len(container.streams.audio):
|
|
||||||
audio_stream = container.streams.audio[-1]
|
|
||||||
audio_frames = []
|
|
||||||
resample = av.audio.resampler.AudioResampler(format='fltp').resample
|
|
||||||
frames = itertools.chain.from_iterable(
|
|
||||||
map(resample, container.decode(audio_stream))
|
|
||||||
)
|
|
||||||
|
|
||||||
has_first_frame = False
|
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||||
for frame in frames:
|
audio = AudioInput({
|
||||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
"waveform": audio_tensor,
|
||||||
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
|
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
|
||||||
if to_skip < frame.samples:
|
})
|
||||||
has_first_frame = True
|
|
||||||
break
|
|
||||||
if has_first_frame:
|
|
||||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
|
||||||
|
|
||||||
for frame in frames:
|
|
||||||
if self.__duration and frame.time > start_time + self.__duration:
|
|
||||||
break
|
|
||||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
|
||||||
if len(audio_frames) > 0:
|
|
||||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
|
||||||
if self.__duration:
|
|
||||||
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
|
|
||||||
|
|
||||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
|
||||||
audio = AudioInput({
|
|
||||||
"waveform": audio_tensor,
|
|
||||||
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
|
|
||||||
})
|
|
||||||
|
|
||||||
metadata = container.metadata
|
metadata = container.metadata
|
||||||
return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user