Merge branch 'Comfy-Org:master' into master

This commit is contained in:
azazeal04 2026-06-15 13:12:02 +02:00 committed by GitHub
commit f61ad29a21
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 314 additions and 24 deletions

View File

@ -382,11 +382,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
### AMD ROCm Tips ### AMD ROCm Tips
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default. You can try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
# Notes # Notes

View File

@ -8,6 +8,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from comfy.ldm.lightricks.model import Timesteps from comfy.ldm.lightricks.model import Timesteps
from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.model_management import comfy.model_management
import comfy.ldm.common_dit import comfy.ldm.common_dit
@ -17,9 +18,7 @@ def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0: if x.shape[1] == 0:
return x return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2) return apply_rope1(x, freqs_cis)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape).to(dtype=x.dtype)
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

View File

@ -27,10 +27,13 @@ class VideoInput(ABC):
path: Union[str, IO[bytes]], path: Union[str, IO[bytes]],
format: VideoContainer = VideoContainer.AUTO, format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None metadata: Optional[dict] = None,
bit_depth: int | None = None,
): ):
""" """
Abstract method to save the video input to a file. Abstract method to save the video input to a file.
bit_depth selects the encoded bit depth; None keeps the video's native depth.
""" """
pass pass
@ -83,6 +86,14 @@ class VideoInput(ABC):
components = self.get_components() components = self.get_components()
return components.images.shape[2], components.images.shape[1] return components.images.shape[2], components.images.shape[1]
def get_bit_depth(self) -> int:
    """
    Report the bit depth of this video (for example 8 or 10).

    This base implementation always answers 8; subclasses that know their
    actual depth are expected to override it.
    """
    default_depth = 8
    return default_depth
def get_duration(self) -> float: def get_duration(self) -> float:
""" """
Returns the duration of the video in seconds. Returns the duration of the video in seconds.

View File

@ -52,6 +52,12 @@ def get_open_write_kwargs(
return open_kwargs return open_kwargs
def video_stream_bit_depth(stream) -> int:
    """Return the largest per-component bit count of a stream's pixel format.

    Falls back to 8 when there is no stream, when the stream has no pixel
    format, or when that format lists no components.
    """
    pixel_format = None if stream is None else stream.format
    if pixel_format is None:
        return 8
    components = pixel_format.components
    if not components:
        return 8
    return max(c.bits for c in components)
class VideoFromFile(VideoInput): class VideoFromFile(VideoInput):
""" """
Class representing video input from a file. Class representing video input from a file.
@ -97,6 +103,13 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'") raise ValueError(f"No video stream found in file '{self.__file}'")
def get_bit_depth(self) -> int:
    """Open the source and return the bit depth of its first video stream.

    The value comes from video_stream_bit_depth, which is handed None when
    the container holds no video stream.
    """
    source = self.__file
    if isinstance(source, io.BytesIO):
        # Rewind in-memory sources so the container is read from the start.
        source.seek(0)
    with av.open(source, mode="r") as container:
        video_streams = container.streams.video
        first_stream = video_streams[0] if len(video_streams) > 0 else None
        return video_stream_bit_depth(first_stream)
def get_duration(self) -> float: def get_duration(self) -> float:
""" """
Returns the duration of the video in seconds. Returns the duration of the video in seconds.
@ -257,6 +270,7 @@ class VideoFromFile(VideoInput):
image_format = 'gbrpf32le' image_format = 'gbrpf32le'
process_image_format = lambda a: a process_image_format = lambda a: a
align_graph = None
audio = None audio = None
streams = [video_stream] streams = [video_stream]
@ -310,7 +324,24 @@ class VideoFromFile(VideoInput):
checked_alpha = True checked_alpha = True
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4) # Fix non-deterministic video decode when the video width is not a multiple of 32
# For non-yuvj pixel formats (all H.264/H.265 video)
if image_format in ('gbrpf32le', 'gbrapf32le') and frame.width % 32 != 0:
if align_graph is None:
pad_w = ((frame.width + 31) // 32) * 32
g = av.filter.Graph()
g_src = g.add_buffer(width=frame.width, height=frame.height,
format=frame.format.name, time_base=video_stream.time_base)
g_pad = g.add('pad', f'{pad_w}:{frame.height}:0:0')
g_sink = g.add('buffersink')
g_src.link_to(g_pad)
g_pad.link_to(g_sink)
g.configure()
align_graph = (g, g_src, g_sink)
align_graph[1].push(frame)
img = np.ascontiguousarray(align_graph[2].pull().to_ndarray(format=image_format)[:, :frame.width])
else:
img = frame.to_ndarray(format=image_format)
if frame.rotation != 0: if frame.rotation != 0:
k = int(round(frame.rotation // 90)) k = int(round(frame.rotation // 90))
img = np.rot90(img, k=k, axes=(0, 1)).copy() img = np.rot90(img, k=k, axes=(0, 1)).copy()
@ -377,25 +408,32 @@ class VideoFromFile(VideoInput):
format: VideoContainer = VideoContainer.AUTO, format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
bit_depth: int | None = None,
): ):
if isinstance(self.__file, io.BytesIO): if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container: with av.open(self.__file, mode='r') as container:
container_format = container.format.name container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
video_encoding = video_stream.codec.name if video_stream is not None else None
source_bit_depth = video_stream_bit_depth(video_stream)
reuse_streams = True reuse_streams = True
if format != VideoContainer.AUTO and format not in container_format.split(","): if format != VideoContainer.AUTO and format not in container_format.split(","):
reuse_streams = False reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False reuse_streams = False
if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth:
reuse_streams = False
if self.__start_time or self.__duration: if self.__start_time or self.__duration:
reuse_streams = False reuse_streams = False
if not reuse_streams: if not reuse_streams:
if bit_depth is None:
bit_depth = source_bit_depth
components = self.get_components_internal(container) components = self.get_components_internal(container)
video = VideoFromComponents(components) video = VideoFromComponents(components)
return video.save_to( return video.save_to(
path, format=format, codec=codec, metadata=metadata path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth,
) )
streams = container.streams streams = container.streams
@ -451,8 +489,10 @@ class VideoFromComponents(VideoInput):
Class representing video input from tensors. Class representing video input from tensors.
""" """
def __init__(self, components: VideoComponents): def __init__(self, components: VideoComponents, bit_depth: int = 8):
self.__components = components self.__components = components
# Tensor components have no inherent bit depth; this is the depth used when encoding.
self.__bit_depth = bit_depth
def get_components(self) -> VideoComponents: def get_components(self) -> VideoComponents:
return VideoComponents( return VideoComponents(
@ -461,18 +501,26 @@ class VideoFromComponents(VideoInput):
frame_rate=self.__components.frame_rate, frame_rate=self.__components.frame_rate,
) )
def get_bit_depth(self) -> int:
return self.__bit_depth
def save_to( def save_to(
self, self,
path: str, path: str,
format: VideoContainer = VideoContainer.AUTO, format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO, codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
bit_depth: int | None = None,
): ):
"""Save the video to a file path or BytesIO buffer.""" """Save the video to a file path or BytesIO buffer."""
if format != VideoContainer.AUTO and format != VideoContainer.MP4: if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now") raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264: if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now") raise ValueError("Only H264 codec is supported for now")
# None means "use the depth this video was created with" (CreateVideo's choice).
if bit_depth is None:
bit_depth = self.__bit_depth
is_10bit = bit_depth >= 10
extra_kwargs = {} extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value extra_kwargs["format"] = format.value
@ -488,10 +536,11 @@ class VideoFromComponents(VideoInput):
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream # Create a video stream
pix_fmt = "yuv420p10le" if is_10bit else "yuv420p"
video_stream = output.add_stream('h264', rate=frame_rate) video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2] video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1] video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p' video_stream.pix_fmt = pix_fmt
# Create an audio stream # Create an audio stream
audio_sample_rate = 1 audio_sample_rate = 1
@ -505,9 +554,14 @@ class VideoFromComponents(VideoInput):
# Encode video # Encode video
for i, frame in enumerate(self.__components.images): for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) if is_10bit:
frame = av.VideoFrame.from_ndarray(img, format='rgb24') # 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format="rgb48le")
else:
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format=pix_fmt)
packet = video_stream.encode(frame) packet = video_stream.encode(frame)
output.mux(packet) output.mux(packet)

View File

@ -208,6 +208,10 @@ class TripoMultiviewToModelRequest(BaseModel):
quad: bool | None = Field(False, description="Whether to apply quad to the generated model") quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
class TripoTexturePrompt(BaseModel):
    # Wrapper object for the optional free-text texture guidance.
    text: str | None = Field(
        default=None,
        description="Text guidance for texture generation",
    )
class TripoTextureModelRequest(BaseModel): class TripoTextureModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task") type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task")
original_model_task_id: str = Field(..., description="The task ID of the original model") original_model_task_id: str = Field(..., description="The task ID of the original model")
@ -219,6 +223,11 @@ class TripoTextureModelRequest(BaseModel):
texture_alignment: TripoTextureAlignment | None = Field( texture_alignment: TripoTextureAlignment | None = Field(
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method" TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
) )
texture_prompt: TripoTexturePrompt | None = Field(
None,
description="Optional guidance for texturing. Required in practice for imported models, "
"which carry no source image to infer texture from.",
)
class TripoRefineModelRequest(BaseModel): class TripoRefineModelRequest(BaseModel):
@ -307,6 +316,17 @@ class TripoP1MultiviewToModelRequest(TripoP1CommonRequest):
orientation: str | None = None orientation: str | None = None
class TripoImportModelRequest(BaseModel):
    """Request for the comfy-api composite import endpoint (/proxy/tripo/v2/openapi/import).
    The model file is uploaded to ComfyUI API storage first; the backend downloads it from
    `url`, re-uploads it to Tripo's storage and creates the import_model task server-side.
    """

    # Both fields are required (no default).
    url: str = Field(
        ...,
        description="ComfyUI API storage download URL of the model file",
    )
    format: str = Field(
        ...,
        description='File format: "glb", "fbx", "obj" or "stl"',
    )
class TripoTaskOutput(BaseModel): class TripoTaskOutput(BaseModel):
model: str | None = Field(None, description="URL to the model") model: str | None = Field(None, description="URL to the model")
base_model: str | None = Field(None, description="URL to the base model") base_model: str | None = Field(None, description="URL to the base model")

View File

@ -1,6 +1,6 @@
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.tripo import ( from comfy_api_nodes.apis.tripo import (
TripoAnimateRetargetRequest, TripoAnimateRetargetRequest,
TripoAnimateRigRequest, TripoAnimateRigRequest,
@ -8,6 +8,7 @@ from comfy_api_nodes.apis.tripo import (
TripoFileEmptyReference, TripoFileEmptyReference,
TripoFileReference, TripoFileReference,
TripoImageToModelRequest, TripoImageToModelRequest,
TripoImportModelRequest,
TripoModelVersion, TripoModelVersion,
TripoMultiviewToModelRequest, TripoMultiviewToModelRequest,
TripoOrientation, TripoOrientation,
@ -21,6 +22,7 @@ from comfy_api_nodes.apis.tripo import (
TripoTaskType, TripoTaskType,
TripoTextToModelRequest, TripoTextToModelRequest,
TripoTextureModelRequest, TripoTextureModelRequest,
TripoTexturePrompt,
TripoUrlReference, TripoUrlReference,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
@ -28,6 +30,7 @@ from comfy_api_nodes.util import (
download_url_to_file_3d, download_url_to_file_3d,
poll_op, poll_op,
sync_op, sync_op,
upload_3d_model_to_comfyapi,
upload_images_to_comfyapi, upload_images_to_comfyapi,
) )
@ -538,6 +541,14 @@ class TripoTextureNode(IO.ComfyNode):
optional=True, optional=True,
advanced=True, advanced=True,
), ),
IO.String.Input(
"texture_prompt",
default="",
multiline=True,
optional=True,
tooltip="Optional text guidance for texturing. Required in practice for imported "
"models (Tripo: Import Model), which carry no source image to infer colors from.",
),
], ],
outputs=[ outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -571,6 +582,7 @@ class TripoTextureNode(IO.ComfyNode):
texture_seed: int | None = None, texture_seed: int | None = None,
texture_quality: str | None = None, texture_quality: str | None = None,
texture_alignment: str | None = None, texture_alignment: str | None = None,
texture_prompt: str = "",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
response = await sync_op( response = await sync_op(
cls, cls,
@ -583,6 +595,7 @@ class TripoTextureNode(IO.ComfyNode):
texture_seed=texture_seed, texture_seed=texture_seed,
texture_quality=texture_quality, texture_quality=texture_quality,
texture_alignment=texture_alignment, texture_alignment=texture_alignment,
texture_prompt=TripoTexturePrompt(text=texture_prompt.strip()) if texture_prompt.strip() else None,
), ),
) )
return await poll_until_finished(cls, response, average_duration=80) return await poll_until_finished(cls, response, average_duration=80)
@ -915,6 +928,90 @@ class TripoConversionNode(IO.ComfyNode):
return await poll_until_finished(cls, response, average_duration=30) return await poll_until_finished(cls, response, average_duration=30)
class TripoImportModelNode(IO.ComfyNode):
    """Imports an external 3D model into Tripo, producing a MODEL_TASK_ID for post-processing nodes."""

    SUPPORTED_FORMATS = ("glb", "fbx", "obj", "stl")

    @classmethod
    def define_schema(cls):
        """Describe the node: one 3D model input, one MODEL_TASK_ID output, hidden auth inputs."""
        model_input = IO.MultiType.Input(
            "model_3d",
            types=[IO.File3DGLB, IO.File3DFBX, IO.File3DOBJ, IO.File3DSTL, IO.File3DAny],
            tooltip="3D model to import (GLB / FBX / OBJ / STL, up to 150 MB). "
            "OBJ and STL files carry no embedded textures.",
        )
        task_id_output = IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id")
        hidden_inputs = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="TripoImportModelNode",
            display_name="Tripo: Import Model",
            category="partner/3d/Tripo",
            description="Import an external 3D model (e.g. from Rodin, Hunyuan3D or a local file) into Tripo "
            "to use it with Tripo's post-processing nodes: Texture, Rig, Convert. "
            "GLB is recommended: textures survive import only when embedded in the file. "
            "Note that texturing an imported model requires a texture prompt.",
            inputs=[model_input],
            outputs=[task_id_output],
            hidden=hidden_inputs,
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"text","text":"Free"}""",
            ),
        )

    @classmethod
    async def execute(cls, model_3d: Types.File3D) -> IO.NodeOutput:
        """Validate the model, upload it, create the import task and wait for it to succeed.

        Raises:
            ValueError: for .gltf input, an unsupported format, or a file over 150 MB.
            RuntimeError: when task creation returns a non-zero code, or when
                polling ends in any status other than SUCCESS.
        """
        # Normalise the reported format: tolerate a leading dot and mixed case.
        file_format = (model_3d.format or "").lstrip(".").lower()
        if file_format == "gltf":
            raise ValueError(
                "GLTF (.gltf) references external files and cannot be imported. Export a single-file GLB instead."
            )
        if file_format not in cls.SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported 3D format '{file_format or 'unknown'}'. "
                f"Tripo import supports: {', '.join(f.upper() for f in cls.SUPPORTED_FORMATS)}."
            )

        # Check the size limit before spending time on the upload.
        bytes_per_mb = 1024 * 1024
        size = len(model_3d.get_bytes())
        if size > 150 * bytes_per_mb:
            raise ValueError(f"Model file is {size / bytes_per_mb:.1f} MB; Tripo import allows up to 150 MB.")

        url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)

        create_response = await sync_op(
            cls,
            endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/import", method="POST"),
            response_model=TripoTaskResponse,
            data=TripoImportModelRequest(url=url, format=file_format),
        )
        if create_response.code != 0:
            raise RuntimeError(f"Failed to import model: {create_response.error}")

        task_id = create_response.data.task_id
        terminal_failures = [
            TripoTaskStatus.FAILED,
            TripoTaskStatus.CANCELLED,
            TripoTaskStatus.UNKNOWN,
            TripoTaskStatus.BANNED,
            TripoTaskStatus.EXPIRED,
        ]
        poll_response = await poll_op(
            cls,
            poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"),
            response_model=TripoTaskResponse,
            failed_statuses=terminal_failures,
            status_extractor=lambda resp: resp.data.status,
            progress_extractor=lambda resp: resp.data.progress,
            estimated_duration=10,
        )
        if poll_response.data.status != TripoTaskStatus.SUCCESS:
            raise RuntimeError(f"Failed to import model: {poll_response}")
        return IO.NodeOutput(task_id)
def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str: def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str:
return ( return (
"(" "("
@ -1292,6 +1389,7 @@ class TripoExtension(ComfyExtension):
TripoP1TextToModelNode, TripoP1TextToModelNode,
TripoP1ImageToModelNode, TripoP1ImageToModelNode,
TripoP1MultiviewToModelNode, TripoP1MultiviewToModelNode,
TripoImportModelNode,
TripoTextureNode, TripoTextureNode,
TripoRefineNode, TripoRefineNode,
TripoRigNode, TripoRigNode,

View File

@ -134,6 +134,17 @@ class CreateVideo(io.ComfyNode):
io.Image.Input("images", tooltip="The images to create a video from."), io.Image.Input("images", tooltip="The images to create a video from."),
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
io.Int.Input(
"bit_depth",
min=8,
max=10,
default=8,
step=2,
tooltip="Bit depth of the created video. 10-bit keeps smoother gradients with less"
" banding, but some players and downstream nodes may not support it.",
optional=True,
display_mode=io.NumberDisplay.number,
),
], ],
outputs=[ outputs=[
io.Video.Output(), io.Video.Output(),
@ -141,9 +152,14 @@ class CreateVideo(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput: def execute(
cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None, bit_depth: int = 8,
) -> io.NodeOutput:
return io.NodeOutput( return io.NodeOutput(
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) InputImpl.VideoFromComponents(
Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)),
bit_depth=bit_depth,
)
) )
class GetVideoComponents(io.ComfyNode): class GetVideoComponents(io.ComfyNode):
@ -154,7 +170,7 @@ class GetVideoComponents(io.ComfyNode):
search_aliases=["extract frames", "split video", "video to images", "demux"], search_aliases=["extract frames", "split video", "video to images", "demux"],
display_name="Get Video Components", display_name="Get Video Components",
category="video", category="video",
description="Extracts all components from a video: frames, audio, and framerate.", description="Extracts all components from a video: frames, audio, framerate, and bit depth.",
inputs=[ inputs=[
io.Video.Input("video", tooltip="The video to extract components from."), io.Video.Input("video", tooltip="The video to extract components from."),
], ],
@ -162,13 +178,14 @@ class GetVideoComponents(io.ComfyNode):
io.Image.Output(display_name="images"), io.Image.Output(display_name="images"),
io.Audio.Output(display_name="audio"), io.Audio.Output(display_name="audio"),
io.Float.Output(display_name="fps"), io.Float.Output(display_name="fps"),
io.Int.Output(display_name="bit_depth"),
], ],
) )
@classmethod @classmethod
def execute(cls, video: Input.Video) -> io.NodeOutput: def execute(cls, video: Input.Video) -> io.NodeOutput:
components = video.get_components() components = video.get_components()
return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) return io.NodeOutput(components.images, components.audio, float(components.frame_rate), video.get_bit_depth())
class LoadVideo(io.ComfyNode): class LoadVideo(io.ComfyNode):

View File

@ -1 +1 @@
comfyui_manager==4.2.1 comfyui_manager==4.2.2

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.45.15 comfyui-frontend-package==1.45.15
comfyui-workflow-templates==0.9.98 comfyui-workflow-templates==0.9.98
comfyui-embedded-docs==0.5.3 comfyui-embedded-docs==0.5.4
torch torch
torchsde torchsde
torchvision torchvision

View File

@ -27,6 +27,7 @@ import logging
import mimetypes import mimetypes
from comfy.cli_args import args from comfy.cli_args import args
from comfy.deploy_environment import get_deploy_environment
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
from comfy_api import feature_flags from comfy_api import feature_flags
@ -690,6 +691,7 @@ class PromptServer():
"python_version": sys.version, "python_version": sys.version,
"pytorch_version": comfy.model_management.torch_version, "pytorch_version": comfy.model_management.torch_version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
"deploy_environment": get_deploy_environment(),
"argv": sys.argv "argv": sys.argv
}, },
"devices": device_entries "devices": device_entries

View File

@ -0,0 +1,93 @@
import pytest
import torch
import av
import numpy as np
from fractions import Fraction
from comfy_api.latest._input_impl.video_types import VideoFromFile, VideoFromComponents
from comfy_api.latest._util.video_types import VideoComponents
@pytest.fixture(scope="module")
def gradient_components():
    """Shallow left-to-right ramp (0.25..0.30), too fine for 8 bits to render smoothly"""
    frames, height, width = 3, 64, 64
    row = torch.linspace(0.25, 0.30, width)
    # Broadcast the single row over every frame, scanline and colour channel.
    ramp = row.view(1, 1, width, 1).expand(frames, height, width, 3)
    return VideoComponents(images=ramp.contiguous(), frame_rate=Fraction(30))
@pytest.fixture(scope="module")
def src8(gradient_components, tmp_path_factory):
    """Path to an 8-bit h264 mp4 (the Create Video default depth)"""
    out_dir = tmp_path_factory.mktemp("video")
    path = str(out_dir / "src8.mp4")
    VideoFromComponents(gradient_components).save_to(path)
    return path
@pytest.fixture(scope="module")
def src10(gradient_components, tmp_path_factory):
    """Path to a 10-bit h264 mp4 (Create Video with bit_depth=10)"""
    out_dir = tmp_path_factory.mktemp("video")
    path = str(out_dir / "src10.mp4")
    VideoFromComponents(gradient_components, bit_depth=10).save_to(path)
    return path
def probe(path):
    """Return (codec name, pixel format name, bit depth) for the first video stream"""
    with av.open(path) as container:
        stream = container.streams.video[0]
        pixel_format = stream.format
        # Bit depth is taken as the widest component of the pixel format.
        depth = max(c.bits for c in pixel_format.components)
        return (stream.codec.name, pixel_format.name, depth)
def decoded_levels(path):
    """Count distinct values in channel 0 of the first decoded frame (banding measure)"""
    with av.open(path) as container:
        video_stream = container.streams.video[0]
        first_frame = next(container.decode(video_stream))
        channel = first_frame.to_ndarray(format="gbrpf32le")[..., 0]
        return np.unique(channel).size
def video_packet_bytes(path):
    """Collect the raw payload of every non-empty video packet; matches the source's only for a true remux"""
    payloads = []
    with av.open(path) as container:
        for packet in container.demux(container.streams.video[0]):
            # Zero-size packets are skipped.
            if packet.size:
                payloads.append(bytes(packet))
    return payloads
def test_create_video_bit_depth(src8, src10):
    """bit_depth on creation selects the encoded depth (8-bit by default); 10-bit shows less banding"""
    assert probe(src8) == ("h264", "yuv420p", 8)
    assert probe(src10) == ("h264", "yuv420p10le", 10)

    levels_10bit = decoded_levels(src10)
    levels_8bit = decoded_levels(src8)
    # The 10-bit encode must keep more than twice as many distinct levels.
    assert levels_10bit > 2 * levels_8bit
def test_save_auto_keeps_source_depth(src8, src10, tmp_path):
    """Saving without bit_depth (auto) stream-copies the source, so depth and packets are unchanged"""
    sources = {"p8": src8, "p10": src10}
    for name, src in sources.items():
        path = str(tmp_path / f"{name}.mp4")
        VideoFromFile(src).save_to(path)
        assert probe(path) == probe(src)
        # Identical packet payloads show the stream was copied, not re-encoded.
        assert video_packet_bytes(path) == video_packet_bytes(src)
def test_save_explicit_depth_reencodes(src8, src10, tmp_path):
    """Requesting a bit_depth that differs from the source forces a re-encode to that depth"""
    cases = [
        # (output file, source, requested depth, expected probe result)
        ("down8.mp4", src10, 8, ("h264", "yuv420p", 8)),
        ("up10.mp4", src8, 10, ("h264", "yuv420p10le", 10)),
    ]
    for filename, source, depth, expected in cases:
        target = str(tmp_path / filename)
        VideoFromFile(source).save_to(target, bit_depth=depth)
        assert probe(target) == expected
def test_trim_keeps_source_depth(src10, tmp_path):
    """Trimming (Video Slice) re-encodes, yet the output stays at the source's 10-bit depth"""
    path = str(tmp_path / "trim.mp4")
    clip_seconds = 1 / 30
    trimmed = VideoFromFile(src10).as_trimmed(start_time=0, duration=clip_seconds, strict_duration=False)
    trimmed.save_to(path)
    assert probe(path) == ("h264", "yuv420p10le", 10)
def test_get_bit_depth(gradient_components, src8, src10):
    """get_bit_depth reports each video's depth (it feeds the Get Video Components output)"""
    # File-backed videos report the depth probed from the file.
    file_8bit = VideoFromFile(src8)
    assert file_8bit.get_bit_depth() == 8
    file_10bit = VideoFromFile(src10)
    assert file_10bit.get_bit_depth() == 10

    # Tensor-backed videos report the depth they were constructed with (default 8).
    tensor_10bit = VideoFromComponents(gradient_components, bit_depth=10)
    assert tensor_10bit.get_bit_depth() == 10
    tensor_default = VideoFromComponents(gradient_components)
    assert tensor_default.get_bit_depth() == 8