diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index eb4d3701d..812b3eb30 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -12,6 +12,7 @@ import numpy as np
 import math
 import torch
 from .._util import VideoContainer, VideoCodec, VideoComponents
+import logging


 def container_to_output_format(container_format: str | None) -> str | None:
@@ -238,32 +239,86 @@ class VideoFromFile(VideoInput):
             start_time = max(self._get_raw_duration() + self.__start_time, 0)
         else:
             start_time = self.__start_time
+        # Get video frames
         frames = []
+        audio_frames = []
         alphas = None
         start_pts = int(start_time / video_stream.time_base)
         end_pts = int((start_time + self.__duration) / video_stream.time_base)
-        container.seek(start_pts, stream=video_stream)
-        image_format = 'gbrpf32le'
-        for frame in container.decode(video_stream):
-            if alphas is None:
-                for comp in frame.format.components:
-                    if comp.is_alpha:
-                        alphas = []
-                        image_format = 'gbrapf32le'
-                        break
-            if frame.pts < start_pts:
-                continue
-            if self.__duration and frame.pts >= end_pts:
+        if start_pts != 0:
+            container.seek(start_pts, stream=video_stream)
+
+        image_format = 'gbrpf32le'
+        audio = None
+
+        streams = [video_stream]
+        has_first_audio_frame = False
+        checked_alpha = False
+
+        # video_done defaults to False so we decode until EOF when duration is 0; audio_done flips to False below only if an audio stream exists
+        video_done = False
+        audio_done = True
+
+        if len(container.streams.audio):
+            audio_stream = container.streams.audio[-1]
+            streams += [audio_stream]
+            resampler = av.audio.resampler.AudioResampler(format='fltp')
+            audio_done = False
+
+        for packet in container.demux(*streams):
+            if video_done and audio_done:
                 break
-            img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
-            if alphas is None:
-                frames.append(torch.from_numpy(img))
-            else:
-                frames.append(torch.from_numpy(img[..., :-1]))
-                alphas.append(torch.from_numpy(img[..., -1:]))
+            if packet.stream.type == "video":
+                if video_done:
+                    continue
+                try:
+                    for frame in packet.decode():
+                        if frame.pts < start_pts:
+                            continue
+                        if self.__duration and frame.pts >= end_pts:
+                            video_done = True
+                            break
+
+                        if not checked_alpha:
+                            for comp in frame.format.components:
+                                if comp.is_alpha:
+                                    alphas = []
+                                    image_format = 'gbrapf32le'
+                                    break
+                            checked_alpha = True
+
+                        img = frame.to_ndarray(format=image_format)  # shape: (H, W, 3), or (H, W, 4) when alpha is present
+                        if alphas is None:
+                            frames.append(torch.from_numpy(img))
+                        else:
+                            frames.append(torch.from_numpy(img[..., :-1]))
+                            alphas.append(torch.from_numpy(img[..., -1:]))
+                except av.error.InvalidDataError:
+                    logging.info("Skipping invalid data while decoding video (InvalidDataError)")
+
+            elif packet.stream.type == "audio":
+                if audio_done:
+                    continue
+
+                aframes = itertools.chain.from_iterable(
+                    map(resampler.resample, packet.decode())
+                )
+                for frame in aframes:
+                    if self.__duration and frame.time > start_time + self.__duration:
+                        audio_done = True
+                        break
+
+                    if not has_first_audio_frame:
+                        # Trim samples that fall before start_time from the first kept frame
+                        offset_seconds = start_time - frame.pts * audio_stream.time_base
+                        to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
+                        if to_skip < frame.samples:
+                            has_first_audio_frame = True
+                            audio_frames.append(frame.to_ndarray()[..., to_skip:])
+                    else:
+                        audio_frames.append(frame.to_ndarray())

         images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
         if alphas is not None:
@@ -272,42 +327,16 @@ class VideoFromFile(VideoInput):
         # Get frame rate
         frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)

-        # Get audio if available
-        audio = None
-        container.seek(start_pts, stream=video_stream)
-        # Use last stream for consistency
-        if len(container.streams.audio):
-            audio_stream = container.streams.audio[-1]
-            audio_frames = []
-            resample = av.audio.resampler.AudioResampler(format='fltp').resample
-            frames = itertools.chain.from_iterable(
-                map(resample, container.decode(audio_stream))
-            )
+        if len(audio_frames) > 0:
+            audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
+            if self.__duration:
+                audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]

-            has_first_frame = False
-            for frame in frames:
-                offset_seconds = start_time - frame.pts * audio_stream.time_base
-                to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
-                if to_skip < frame.samples:
-                    has_first_frame = True
-                    break
-            if has_first_frame:
-                audio_frames.append(frame.to_ndarray()[..., to_skip:])
-
-            for frame in frames:
-                if self.__duration and frame.time > start_time + self.__duration:
-                    break
-                audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
-            if len(audio_frames) > 0:
-                audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
-                if self.__duration:
-                    audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
-
-                audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
-                audio = AudioInput({
-                    "waveform": audio_tensor,
-                    "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
-                })
+            audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
+            audio = AudioInput({
+                "waveform": audio_tensor,
+                "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
+            })

         metadata = container.metadata
         return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata)
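
Reviewer note (not part of the patch): below is a minimal standalone sketch of the single-pass pattern the new code relies on, where av.Container.demux() interleaves packets from the selected streams and each packet is decoded according to its type, replacing the old two-pass seek-and-decode. The file name "input.mp4" and the counters are hypothetical; only PyAV is assumed.

    import itertools

    import av

    with av.open("input.mp4") as container:
        video_stream = container.streams.video[0]
        streams = [video_stream]

        resampler = None
        if len(container.streams.audio):
            # Mirror the patch: take the last audio stream, resample to planar float
            streams.append(container.streams.audio[-1])
            resampler = av.audio.resampler.AudioResampler(format='fltp')

        n_frames = 0
        n_samples = 0
        # One pass over the file: video and audio packets arrive interleaved
        for packet in container.demux(*streams):
            if packet.stream.type == "video":
                n_frames += len(packet.decode())
            elif packet.stream.type == "audio":
                # A single packet can yield several frames after resampling
                for frame in itertools.chain.from_iterable(
                    map(resampler.resample, packet.decode())
                ):
                    n_samples += frame.samples

        print(f"{n_frames} video frames, {n_samples} audio samples")

Demuxing both streams together avoids the second container.seek() and separate audio decode of the removed code, so the file is read once end to end.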