diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 6ed41bba8..9a107fb76 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -251,6 +251,7 @@ class VideoFromFile(VideoInput):
         container.seek(start_pts, stream=video_stream)
 
         image_format = 'gbrpf32le'
+        process_image_format = lambda a: a
         audio = None
         streams = [video_stream]
 
@@ -283,11 +284,25 @@ class VideoFromFile(VideoInput):
                     break
 
             if not checked_alpha:
+                alpha_channel = False
                 for comp in frame.format.components:
                     if comp.is_alpha or frame.format.name == "pal8":
                         alphas = []
-                        image_format = 'gbrapf32le'
+                        alpha_channel = True
                         break
+                if frame.format.name in ("yuvj420p", "rgb24", "rgba", "pal8"):
+                    process_image_format = lambda a: a.float() / 255.0
+                    if alpha_channel:
+                        image_format = 'rgba'
+                    else:
+                        image_format = 'rgb24'
+                else:
+                    process_image_format = lambda a: a
+                    if alpha_channel:
+                        image_format = 'gbrapf32le'
+                    else:
+                        image_format = 'gbrpf32le'
+                checked_alpha = True
 
             img = frame.to_ndarray(format=image_format)  # shape: (H, W, 4)
 
@@ -323,9 +338,9 @@ class VideoFromFile(VideoInput):
             else:
                 audio_frames.append(frame.to_ndarray())
 
-        images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
+        images = process_image_format(torch.stack(frames)) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
         if alphas is not None:
-            alphas = torch.stack(alphas) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1)
+            alphas = process_image_format(torch.stack(alphas)) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1)
 
         # Get frame rate
         frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)