From b10a61615cd3c64affff3ae62d35905c9a66712a Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Fri, 29 May 2026 13:42:17 +0800 Subject: [PATCH 01/32] chore: update workflow templates to v0.9.91 (#14163) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 651315cb2..0617667e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.44.19 -comfyui-workflow-templates==0.9.85 +comfyui-workflow-templates==0.9.91 comfyui-embedded-docs==0.5.1 torch torchsde From e7214d78eef4c87cd042bc29ec322ad6a2d1509b Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Fri, 29 May 2026 03:06:00 -0400 Subject: [PATCH 02/32] feat: add model_info output to Load3D node (#14144) --- comfy_api/latest/_io.py | 13 +++++++++++++ comfy_extras/nodes_load_3d.py | 4 +++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index fed8dc7f0..19d8176b0 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -777,6 +777,17 @@ class Load3DCamera(ComfyTypeIO): Type = CameraInfo +@comfytype(io_type="LOAD3D_MODEL_INFO") +class Load3DModelInfo(ComfyTypeIO): + class Model3DTransform(TypedDict): + # Coordinate system: right-handed, Y-up, world space + position: dict[str, float | int] # scene units + quaternion: dict[str, float | int] # normalized, dimensionless; world rotation + scale: dict[str, float | int] # dimensionless multiplier + + Type = list[Model3DTransform] + + @comfytype(io_type="LOAD_3D") class Load3D(ComfyTypeIO): """3D models are stored as a dictionary.""" @@ -786,6 +797,7 @@ class Load3D(ComfyTypeIO): normal: str camera_info: Load3DCamera.CameraInfo recording: NotRequired[str] + model_3d_info: NotRequired[list[Load3DModelInfo.Model3DTransform]] Type = Model3DDict @@ -2298,6 +2310,7 @@ __all__ = [ "FlowControl", "Accumulation", "Load3DCamera", + "Load3DModelInfo", "Load3D", "Load3DAnimation", "Photomaker", diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 9c27c0191..6f05f050e 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -47,6 +47,7 @@ class Load3D(IO.ComfyNode): IO.Load3DCamera.Output(display_name="camera_info"), IO.Video.Output(display_name="recording_video"), IO.File3DAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), ], ) @@ -73,7 +74,8 @@ class Load3D(IO.ComfyNode): if model_file and model_file != "none": file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) mesh_path = model_file - return IO.NodeOutput(output_image, output_mask, mesh_path, normal_image, image['camera_info'], video, file_3d) + model_3d_info = image.get('model_3d_info', []) + return IO.NodeOutput(output_image, output_mask, mesh_path, normal_image, image['camera_info'], video, file_3d, model_3d_info) process = execute # TODO: remove From ea5b09257666374904cbf7a9a7a97ca2edbae43c Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 29 May 2026 19:08:43 +0300 Subject: [PATCH 03/32] [Partner Nodes] fix: removed "beta" models versions from Grok nodes (#14170) --- comfy_api_nodes/nodes_grok.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index 43e3cdc26..a41da42f3 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -58,7 +58,6 @@ class GrokImageNode(IO.ComfyNode): "grok-imagine-image-quality", "grok-imagine-image-pro", "grok-imagine-image", - "grok-imagine-image-beta", ], ), IO.String.Input( @@ -233,7 +232,6 @@ class GrokImageEditNode(IO.ComfyNode): "grok-imagine-image-quality", "grok-imagine-image-pro", "grok-imagine-image", - "grok-imagine-image-beta", ], ), IO.Image.Input("image", display_name="images"), @@ -506,7 +504,7 @@ class GrokVideoNode(IO.ComfyNode): category="video/partner/Grok", description="Generate video from a prompt or an image", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video"]), IO.String.Input( "prompt", multiline=True, @@ -576,8 +574,6 @@ class GrokVideoNode(IO.ComfyNode): seed: int, image: Input.Image | None = None, ) -> IO.NodeOutput: - if model == "grok-imagine-video-beta": - model = "grok-imagine-video" image_url = None if image is not None: if get_number_of_images(image) != 1: @@ -618,7 +614,7 @@ class GrokVideoEditNode(IO.ComfyNode): category="video/partner/Grok", description="Edit an existing video based on a text prompt.", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video"]), IO.String.Input( "prompt", multiline=True, From 54d5be4a8e69749a87280e9a1a9e10d4a7aad3a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Fri, 29 May 2026 19:14:32 +0300 Subject: [PATCH 04/32] Fix background removal mask output shape (#14171) --- comfy/bg_removal_model.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/comfy/bg_removal_model.py b/comfy/bg_removal_model.py index 6dec65e63..c772c5f6a 100644 --- a/comfy/bg_removal_model.py +++ b/comfy/bg_removal_model.py @@ -55,12 +55,7 @@ class BackgroundRemovalModel(): out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False) mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) - if mask.ndim == 3: - mask = mask.unsqueeze(0) - if mask.shape[1] != 1: - mask = mask.movedim(-1, 1) - - return mask + return mask.squeeze(1) # (B, 1, H, W) -> (B, H, W) def load_background_removal_model(sd): From ec1896aceb012697f0bbbc3a941b50f06e030faa Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 29 May 2026 19:19:53 +0300 Subject: [PATCH 05/32] [Partner Nodes] feat: add new nodes for Tripo3D P1 model (#14155) --- comfy_api_nodes/apis/tripo.py | 325 +++++++++++++---------- comfy_api_nodes/nodes_tripo.py | 465 ++++++++++++++++++++++++++++++++- 2 files changed, 633 insertions(+), 157 deletions(-) diff --git a/comfy_api_nodes/apis/tripo.py b/comfy_api_nodes/apis/tripo.py index bce6b0e89..7ac81d42c 100644 --- a/comfy_api_nodes/apis/tripo.py +++ b/comfy_api_nodes/apis/tripo.py @@ -1,25 +1,25 @@ from enum import Enum -from typing import Optional, Any +from typing import Any from pydantic import BaseModel, Field, RootModel class TripoModelVersion(str, Enum): - v3_1_20260211 = 'v3.1-20260211' - v3_0_20250812 = 'v3.0-20250812' - v2_5_20250123 = 'v2.5-20250123' - v2_0_20240919 = 'v2.0-20240919' - v1_4_20240625 = 'v1.4-20240625' + v3_1_20260211 = "v3.1-20260211" + v3_0_20250812 = "v3.0-20250812" + v2_5_20250123 = "v2.5-20250123" + v2_0_20240919 = "v2.0-20240919" + v1_4_20240625 = "v1.4-20240625" class TripoGeometryQuality(str, Enum): - standard = 'standard' - detailed = 'detailed' + standard = "standard" + detailed = "detailed" class TripoTextureQuality(str, Enum): - standard = 'standard' - detailed = 'detailed' + standard = "standard" + detailed = "detailed" class TripoStyle(str, Enum): @@ -33,6 +33,7 @@ class TripoStyle(str, Enum): ANCIENT_BRONZE = "ancient_bronze" NONE = "None" + class TripoTaskType(str, Enum): TEXT_TO_MODEL = "text_to_model" IMAGE_TO_MODEL = "image_to_model" @@ -45,26 +46,27 @@ class TripoTaskType(str, Enum): STYLIZE_MODEL = "stylize_model" CONVERT_MODEL = "convert_model" + class TripoTextureAlignment(str, Enum): ORIGINAL_IMAGE = "original_image" GEOMETRY = "geometry" + class TripoOrientation(str, Enum): ALIGN_IMAGE = "align_image" DEFAULT = "default" + class TripoOutFormat(str, Enum): GLB = "glb" FBX = "fbx" -class TripoTopology(str, Enum): - BIP = "bip" - QUAD = "quad" class TripoSpec(str, Enum): MIXAMO = "mixamo" TRIPO = "tripo" + class TripoAnimation(str, Enum): IDLE = "preset:idle" WALK = "preset:walk" @@ -83,11 +85,6 @@ class TripoAnimation(str, Enum): SERPENTINE_MARCH = "preset:serpentine:march" AQUATIC_MARCH = "preset:aquatic:march" -class TripoStylizeStyle(str, Enum): - LEGO = "lego" - VOXEL = "voxel" - VORONOI = "voronoi" - MINECRAFT = "minecraft" class TripoConvertFormat(str, Enum): GLTF = "GLTF" @@ -97,6 +94,7 @@ class TripoConvertFormat(str, Enum): STL = "STL" _3MF = "3MF" + class TripoTextureFormat(str, Enum): BMP = "BMP" DPX = "DPX" @@ -108,6 +106,7 @@ class TripoTextureFormat(str, Enum): TIFF = "TIFF" WEBP = "WEBP" + class TripoTaskStatus(str, Enum): QUEUED = "queued" RUNNING = "running" @@ -118,183 +117,223 @@ class TripoTaskStatus(str, Enum): BANNED = "banned" EXPIRED = "expired" + class TripoFbxPreset(str, Enum): BLENDER = "blender" MIXAMO = "mixamo" _3DSMAX = "3dsmax" + class TripoFileTokenReference(BaseModel): - type: Optional[str] = Field(None, description='The type of the reference') + type: str | None = Field(None, description="The type of the reference") file_token: str + class TripoUrlReference(BaseModel): - type: Optional[str] = Field(None, description='The type of the reference') + type: str | None = Field(None, description="The type of the reference") url: str + class TripoObjectStorage(BaseModel): bucket: str key: str + class TripoObjectReference(BaseModel): type: str object: TripoObjectStorage + class TripoFileEmptyReference(BaseModel): pass + class TripoFileReference(RootModel): root: TripoFileTokenReference | TripoUrlReference | TripoObjectReference | TripoFileEmptyReference -class TripoGetStsTokenRequest(BaseModel): - format: str = Field(..., description='The format of the image') class TripoTextToModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task') - prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024) - negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024) - model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123 - face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') - image_seed: Optional[int] = Field(None, description='The seed for the text') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard - geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard - style: Optional[TripoStyle] = None - auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') - quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') + type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description="Type of task") + prompt: str = Field(..., description="The text prompt describing the model to generate", max_length=1024) + negative_prompt: str | None = Field(None, description="The negative text prompt", max_length=1024) + model_version: TripoModelVersion | None = TripoModelVersion.v2_5_20250123 + face_limit: int | None = Field(None, description="The number of faces to limit the generation to") + texture: bool | None = Field(True, description="Whether to apply texture to the generated model") + pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model") + image_seed: int | None = Field(None, description="The seed for the text") + model_seed: int | None = Field(None, description="The seed for the model") + texture_seed: int | None = Field(None, description="The seed for the texture") + texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard + geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard + style: TripoStyle | None = None + auto_size: bool | None = Field(False, description="Whether to auto-size the model") + quad: bool | None = Field(False, description="Whether to apply quad to the generated model") + class TripoImageToModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task') - file: TripoFileReference = Field(..., description='The file reference to convert to a model') - model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation') - face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard - geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard - texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') - style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model') - auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') - orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT - quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') + type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description="Type of task") + file: TripoFileReference = Field(..., description="The file reference to convert to a model") + model_version: TripoModelVersion | None = Field(None, description="The model version to use for generation") + face_limit: int | None = Field(None, description="The number of faces to limit the generation to") + texture: bool | None = Field(True, description="Whether to apply texture to the generated model") + pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model") + model_seed: int | None = Field(None, description="The seed for the model") + texture_seed: int | None = Field(None, description="The seed for the texture") + texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard + geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard + texture_alignment: TripoTextureAlignment | None = Field( + TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method" + ) + style: TripoStyle | None = Field(None, description="The style to apply to the generated model") + auto_size: bool | None = Field(False, description="Whether to auto-size the model") + orientation: TripoOrientation | None = TripoOrientation.DEFAULT + quad: bool | None = Field(False, description="Whether to apply quad to the generated model") + class TripoMultiviewToModelRequest(BaseModel): type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL - files: list[TripoFileReference] = Field(..., description='The file references to convert to a model') - model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation') - orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection') - face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard - geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard - texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE - auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') - orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model') - quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') + files: list[TripoFileReference] = Field(..., description="The file references to convert to a model") + model_version: TripoModelVersion | None = Field(None, description="The model version to use for generation") + orthographic_projection: bool | None = Field(False, description="Whether to use orthographic projection") + face_limit: int | None = Field(None, description="The number of faces to limit the generation to") + texture: bool | None = Field(True, description="Whether to apply texture to the generated model") + pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model") + model_seed: int | None = Field(None, description="The seed for the model") + texture_seed: int | None = Field(None, description="The seed for the texture") + texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard + geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard + texture_alignment: TripoTextureAlignment | None = TripoTextureAlignment.ORIGINAL_IMAGE + auto_size: bool | None = Field(False, description="Whether to auto-size the model") + orientation: TripoOrientation | None = Field(TripoOrientation.DEFAULT, description="The orientation for the model") + quad: bool | None = Field(False, description="Whether to apply quad to the generated model") + class TripoTextureModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture') - texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') + type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task") + original_model_task_id: str = Field(..., description="The task ID of the original model") + texture: bool | None = Field(True, description="Whether to apply texture to the model") + pbr: bool | None = Field(True, description="Whether to apply PBR to the model") + model_seed: int | None = Field(None, description="The seed for the model") + texture_seed: int | None = Field(None, description="The seed for the texture") + texture_quality: TripoTextureQuality | None = Field(None, description="The quality of the texture") + texture_alignment: TripoTextureAlignment | None = Field( + TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method" + ) + class TripoRefineModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task') - draft_model_task_id: str = Field(..., description='The task ID of the draft model') + type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description="Type of task") + draft_model_task_id: str = Field(..., description="The task ID of the draft model") -class TripoAnimatePrerigcheckRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') class TripoAnimateRigRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') - out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format') - spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging') + type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description="Type of task") + original_model_task_id: str = Field(..., description="The task ID of the original model") + out_format: TripoOutFormat | None = Field(TripoOutFormat.GLB, description="The output format") + spec: TripoSpec | None = Field(TripoSpec.TRIPO, description="The specification for rigging") + class TripoAnimateRetargetRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') - animation: TripoAnimation = Field(..., description='The animation to apply') - out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format') - bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation') + type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description="Type of task") + original_model_task_id: str = Field(..., description="The task ID of the original model") + animation: TripoAnimation = Field(..., description="The animation to apply") + out_format: TripoOutFormat | None = Field(TripoOutFormat.GLB, description="The output format") + bake_animation: bool | None = Field(True, description="Whether to bake the animation") -class TripoStylizeModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task') - style: TripoStylizeStyle = Field(..., description='The style to apply to the model') - original_model_task_id: str = Field(..., description='The task ID of the original model') - block_size: Optional[int] = Field(80, description='The block size for stylization') class TripoConvertModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task') - format: TripoConvertFormat = Field(..., description='The format to convert to') - original_model_task_id: str = Field(..., description='The task ID of the original model') - quad: Optional[bool] = Field(None, description='Whether to apply quad to the model') - force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry') - face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to') - flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model') - flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom') - texture_size: Optional[int] = Field(None, description='The size of the texture') - texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture') - pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom') - scale_factor: Optional[float] = Field(None, description='The scale factor for the model') - with_animation: Optional[bool] = Field(None, description='Whether to include animations') - pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs') - bake: Optional[bool] = Field(None, description='Whether to bake the model') - part_names: Optional[list[str]] = Field(None, description='The names of the parts to include') - fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export') - export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors') - export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export') - animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place') + type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description="Type of task") + format: TripoConvertFormat = Field(..., description="The format to convert to") + original_model_task_id: str = Field(..., description="The task ID of the original model") + quad: bool | None = Field(None, description="Whether to apply quad to the model") + force_symmetry: bool | None = Field(None, description="Whether to force symmetry") + face_limit: int | None = Field(None, description="The number of faces to limit the conversion to") + flatten_bottom: bool | None = Field(None, description="Whether to flatten the bottom of the model") + flatten_bottom_threshold: float | None = Field(None, description="The threshold for flattening the bottom") + texture_size: int | None = Field(None, description="The size of the texture") + texture_format: TripoTextureFormat | None = Field(TripoTextureFormat.JPEG, description="The format of the texture") + pivot_to_center_bottom: bool | None = Field(None, description="Whether to pivot to the center bottom") + scale_factor: float | None = Field(None, description="The scale factor for the model") + with_animation: bool | None = Field(None, description="Whether to include animations") + pack_uv: bool | None = Field(None, description="Whether to pack the UVs") + bake: bool | None = Field(None, description="Whether to bake the model") + part_names: list[str] | None = Field(None, description="The names of the parts to include") + fbx_preset: TripoFbxPreset | None = Field(None, description="The preset for the FBX export") + export_vertex_colors: bool | None = Field(None, description="Whether to export the vertex colors") + export_orientation: TripoOrientation | None = Field(None, description="The orientation for the export") + animate_in_place: bool | None = Field(None, description="Whether to animate in place") + + +class TripoP1CommonRequest(BaseModel): + """Fields supported by Tripo P1 across all input types.""" + + model_version: str = Field("P1-20260311") + model_seed: int | None = Field(None, description="Random seed for geometry generation") + face_limit: int | None = Field(None, ge=48, le=20000, description="Target face count (48-20000)") + texture: bool | None = Field(None, description="Enable texturing; pbr=True forces this true") + pbr: bool | None = Field(None, description="Enable PBR maps; when true, texture is also enabled") + texture_seed: int | None = Field(None, description="Random seed for texture generation") + texture_quality: str | None = Field(None, description='"standard" or "detailed"') + auto_size: bool | None = Field(None, description="Scale to real-world meters") + compress: str | None = Field(None, description='Only "geometry" is supported') + export_uv: bool | None = Field(None, description="Perform UV unwrapping during generation") + + +class TripoP1TextToModelRequest(TripoP1CommonRequest): + type: str = "text_to_model" + prompt: str = Field(..., max_length=1024) + negative_prompt: str | None = Field(None, max_length=255) + image_seed: int | None = None + + +class TripoP1ImageToModelRequest(TripoP1CommonRequest): + type: str = "image_to_model" + file: TripoFileReference + enable_image_autofix: bool | None = None + texture_alignment: str | None = Field(None, description='"original_image" or "geometry"') + orientation: str | None = Field(None, description='"default" or "align_image"; needs texture=true') + + +class TripoP1MultiviewToModelRequest(TripoP1CommonRequest): + """P1 multiview generation. + + Tripo requires `files` to be exactly four entries in [front, left, back, right] order with `{}` + (TripoFileEmptyReference) for omitted slots; front is required and at least two images total must be provided. + """ + + type: str = "multiview_to_model" + files: list[TripoFileReference] + texture_alignment: str | None = None + orientation: str | None = None class TripoTaskOutput(BaseModel): - model: Optional[str] = Field(None, description='URL to the model') - base_model: Optional[str] = Field(None, description='URL to the base model') - pbr_model: Optional[str] = Field(None, description='URL to the PBR model') - rendered_image: Optional[str] = Field(None, description='URL to the rendered image') - riggable: Optional[bool] = Field(None, description='Whether the model is riggable') + model: str | None = Field(None, description="URL to the model") + base_model: str | None = Field(None, description="URL to the base model") + pbr_model: str | None = Field(None, description="URL to the PBR model") + rendered_image: str | None = Field(None, description="URL to the rendered image") + riggable: bool | None = Field(None, description="Whether the model is riggable") + class TripoTask(BaseModel): - task_id: str = Field(..., description='The task ID') - type: Optional[str] = Field(None, description='The type of task') - status: Optional[TripoTaskStatus] = Field(None, description='The status of the task') - input: Optional[dict[str, Any]] = Field(None, description='The input parameters for the task') - output: Optional[TripoTaskOutput] = Field(None, description='The output of the task') - progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100) - create_time: Optional[int] = Field(None, description='The creation time of the task') - running_left_time: Optional[int] = Field(None, description='The estimated time left for the task') - queue_position: Optional[int] = Field(None, description='The position in the queue') + task_id: str = Field(..., description="The task ID") + type: str | None = Field(None, description="The type of task") + status: TripoTaskStatus | None = Field(None, description="The status of the task") + input: dict[str, Any] | None = Field(None, description="The input parameters for the task") + output: TripoTaskOutput | None = Field(None, description="The output of the task") + progress: int | None = Field(None, description="The progress of the task", ge=0, le=100) + create_time: int | None = Field(None, description="The creation time of the task") + running_left_time: int | None = Field(None, description="The estimated time left for the task") + queue_position: int | None = Field(None, description="The position in the queue") consumed_credit: int | None = Field(None) + class TripoTaskResponse(BaseModel): - code: int = Field(0, description='The response code') - data: TripoTask = Field(..., description='The task data') + code: int = Field(0, description="The response code") + data: TripoTask = Field(..., description="The task data") -class TripoGeneralResponse(BaseModel): - code: int = Field(0, description='The response code') - data: dict[str, str] = Field(..., description='The task ID data') - -class TripoBalanceData(BaseModel): - balance: float = Field(..., description='The account balance') - frozen: float = Field(..., description='The frozen balance') - -class TripoBalanceResponse(BaseModel): - code: int = Field(0, description='The response code') - data: TripoBalanceData = Field(..., description='The balance data') class TripoErrorResponse(BaseModel): - code: int = Field(..., description='The error code') - message: str = Field(..., description='The error message') - suggestion: str = Field(..., description='The suggestion for fixing the error') + code: int = Field(..., description="The error code") + message: str = Field(..., description="The error message") + suggestion: str = Field(..., description="The suggestion for fixing the error") diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 6ee674a18..4820e26c1 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -11,6 +11,9 @@ from comfy_api_nodes.apis.tripo import ( TripoModelVersion, TripoMultiviewToModelRequest, TripoOrientation, + TripoP1ImageToModelRequest, + TripoP1MultiviewToModelRequest, + TripoP1TextToModelRequest, TripoRefineModelRequest, TripoStyle, TripoTaskResponse, @@ -93,10 +96,22 @@ class TripoTextToModelNode(IO.ComfyNode): IO.Int.Input("image_seed", default=42, optional=True, advanced=True), IO.Int.Input("model_seed", default=42, optional=True, advanced=True), IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), + IO.Combo.Input( + "texture_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True, advanced=True), IO.Boolean.Input("quad", default=False, optional=True, advanced=True), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), + IO.Combo.Input( + "geometry_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), ], outputs=[ IO.String.Output(display_name="model_file"), # for backward compatibility only @@ -209,16 +224,36 @@ class TripoImageToModelNode(IO.ComfyNode): IO.Boolean.Input("pbr", default=True, optional=True), IO.Int.Input("model_seed", default=42, optional=True, advanced=True), IO.Combo.Input( - "orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True, advanced=True + "orientation", + options=TripoOrientation, + default=TripoOrientation.DEFAULT, + optional=True, + advanced=True, ), IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True + "texture_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), + IO.Combo.Input( + "texture_alignment", + default="original_image", + options=["original_image", "geometry"], + optional=True, + advanced=True, ), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True), IO.Boolean.Input("quad", default=False, optional=True, advanced=True), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), + IO.Combo.Input( + "geometry_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), ], outputs=[ IO.String.Output(display_name="model_file"), # for backward compatibility only @@ -346,13 +381,35 @@ class TripoMultiviewToModelNode(IO.ComfyNode): IO.Boolean.Input("pbr", default=True, optional=True), IO.Int.Input("model_seed", default=42, optional=True, advanced=True), IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True + "texture_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), + IO.Combo.Input( + "texture_alignment", + default="original_image", + options=["original_image", "geometry"], + optional=True, + advanced=True, ), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True), - IO.Boolean.Input("quad", default=False, optional=True, advanced=True, tooltip="This parameter is deprecated and does nothing."), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), + IO.Boolean.Input( + "quad", + default=False, + optional=True, + advanced=True, + tooltip="This parameter is deprecated and does nothing.", + ), + IO.Combo.Input( + "geometry_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), ], outputs=[ IO.String.Output(display_name="model_file"), # for backward compatibility only @@ -467,9 +524,19 @@ class TripoTextureNode(IO.ComfyNode): IO.Boolean.Input("texture", default=True, optional=True), IO.Boolean.Input("pbr", default=True, optional=True), IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True + "texture_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), + IO.Combo.Input( + "texture_alignment", + default="original_image", + options=["original_image", "geometry"], + optional=True, + advanced=True, ), ], outputs=[ @@ -626,7 +693,7 @@ class TripoRetargetNode(IO.ComfyNode): "preset:hexapod:walk", "preset:octopod:walk", "preset:serpentine:march", - "preset:aquatic:march" + "preset:aquatic:march", ], ), ], @@ -817,7 +884,7 @@ class TripoConversionNode(IO.ComfyNode): # Parse part_names from comma-separated string to list part_names_list = None if part_names and part_names.strip(): - part_names_list = [name.strip() for name in part_names.split(',') if name.strip()] + part_names_list = [name.strip() for name in part_names.split(",") if name.strip()] response = await sync_op( cls, @@ -848,6 +915,373 @@ class TripoConversionNode(IO.ComfyNode): return await poll_until_finished(cls, response, average_duration=30) +def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str: + return ( + "(" + " $mode := widgets.output_mode;" + ' $detailed := $lookup(widgets, "output_mode.texture_quality") = "detailed";' + f' $credits := $mode = "geometry only" ? {geometry_credits} : ($detailed ? {detailed_credits} : {textured_credits});' + ' {"type":"usd","usd": $credits * 0.01, "format": {"approximate": true}}' + ")" + ) + + +def _p1_textured_inputs(*, include_image_alignment: bool) -> list: + """Inputs shown inside the 'Textured' branch of the P1 output_mode DynamicCombo.""" + inputs: list = [ + IO.Boolean.Input("pbr", default=True, tooltip="Include PBR maps. When on, base texture is forced on too."), + IO.Combo.Input("texture_quality", options=["standard", "detailed"], default="standard"), + ] + if include_image_alignment: + inputs.extend( + [ + IO.Combo.Input( + "texture_alignment", + options=["original_image", "geometry"], + default="original_image", + tooltip="Prioritize visual fidelity to the source image, or alignment to the mesh geometry.", + ), + IO.Combo.Input( + "orientation", + options=["default", "align_image"], + default="default", + tooltip="Rotate the output to match the source image. Only applies when textured.", + ), + ] + ) + inputs.append(IO.Int.Input("texture_seed", default=42, advanced=True)) + return inputs + + +def _build_p1_output_mode(*, include_image_alignment: bool) -> IO.DynamicCombo.Input: + return IO.DynamicCombo.Input( + "output_mode", + options=[ + IO.DynamicCombo.Option("Geometry only", []), + IO.DynamicCombo.Option("Textured", _p1_textured_inputs(include_image_alignment=include_image_alignment)), + ], + tooltip='"Geometry only" returns an untextured mesh. "Textured" adds color/PBR maps.', + ) + + +def _resolve_p1_texture_fields(output_mode: dict) -> dict: + """Translate the output_mode DynamicCombo payload into P1 request fields. + + pbr=true forces texture=true server-side, but we send both explicitly so the + intent is visible in the request body and logs. + """ + mode = output_mode["output_mode"] + if mode == "Geometry only": + return {"texture": False, "pbr": False} + out = { + "texture": True, + "pbr": bool(output_mode.get("pbr", True)), + "texture_quality": output_mode.get("texture_quality", "standard"), + "texture_seed": output_mode.get("texture_seed"), + } + if "texture_alignment" in output_mode: + out["texture_alignment"] = output_mode["texture_alignment"] + if "orientation" in output_mode: + out["orientation"] = output_mode["orientation"] + return out + + +def _p1_common_inputs() -> list: + """Inputs shared by all P1 nodes (placed after output_mode).""" + return [ + IO.Int.Input( + "face_limit", + default=-1, + min=-1, + max=20000, + optional=True, + advanced=True, + tooltip="Target face count, 48-20000. -1 lets Tripo pick adaptively.", + ), + IO.Int.Input("model_seed", default=42, optional=True, advanced=True), + IO.Boolean.Input( + "auto_size", + default=False, + optional=True, + advanced=True, + tooltip="Scale the output to approximate real-world meters.", + ), + IO.Boolean.Input( + "export_uv", + default=True, + optional=True, + advanced=True, + tooltip="UV unwrap during generation. Turn off for faster geometry-only runs.", + ), + IO.Boolean.Input( + "compress_geometry", + default=False, + optional=True, + advanced=True, + tooltip="Apply geometry-based compression. Decompress before editing.", + ), + ] + + +def _build_p1_request_kwargs( + *, + output_mode: dict, + face_limit: int, + model_seed: int, + auto_size: bool, + export_uv: bool, + compress_geometry: bool, +) -> dict: + """Common P1 request fields shared by all three node types.""" + kwargs: dict = { + "model_seed": model_seed, + "face_limit": face_limit if face_limit != -1 else None, + "auto_size": auto_size, + "export_uv": export_uv, + "compress": "geometry" if compress_geometry else None, + } + kwargs.update(_resolve_p1_texture_fields(output_mode)) + return kwargs + + +class TripoP1TextToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoP1TextToModelNode", + display_name="Tripo P1: Text to Model", + category="3d/partner/Tripo", + description="Tripo P1 text-to-3D. Optimized for low-poly, game-ready meshes with stable topology.", + inputs=[ + IO.String.Input("prompt", multiline=True, tooltip="Up to 1024 characters."), + IO.String.Input("negative_prompt", multiline=True, optional=True, tooltip="Up to 255 characters."), + _build_p1_output_mode(include_image_alignment=False), + IO.Int.Input("image_seed", default=42, optional=True, advanced=True), + *_p1_common_inputs(), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]), + expr=_p1_price_expr(geometry_credits=30, textured_credits=40, detailed_credits=50), + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + output_mode: dict, + negative_prompt: str | None = None, + image_seed: int | None = None, + face_limit: int = -1, + model_seed: int | None = None, + auto_size: bool = False, + export_uv: bool = True, + compress_geometry: bool = False, + ) -> IO.NodeOutput: + if not prompt: + raise RuntimeError("Prompt is required") + common = _build_p1_request_kwargs( + output_mode=output_mode, + face_limit=face_limit, + model_seed=model_seed, + auto_size=auto_size, + export_uv=export_uv, + compress_geometry=compress_geometry, + ) + request = TripoP1TextToModelRequest( + prompt=prompt, + negative_prompt=negative_prompt or None, + image_seed=image_seed, + **common, + ) + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=request, + ) + return await poll_until_finished(cls, response, average_duration=60) + + +class TripoP1ImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoP1ImageToModelNode", + display_name="Tripo P1: Image to Model", + category="3d/partner/Tripo", + description="Tripo P1 image-to-3D. Optimized for low-poly, game-ready meshes.", + inputs=[ + IO.Image.Input("image"), + _build_p1_output_mode(include_image_alignment=True), + IO.Boolean.Input( + "enable_image_autofix", + default=False, + optional=True, + advanced=True, + tooltip="Pre-process the input image for better generation quality.", + ), + *_p1_common_inputs(), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]), + expr=_p1_price_expr(geometry_credits=40, textured_credits=50, detailed_credits=60), + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + output_mode: dict, + enable_image_autofix: bool = False, + face_limit: int = -1, + model_seed: int | None = None, + auto_size: bool = False, + export_uv: bool = True, + compress_geometry: bool = False, + ) -> IO.NodeOutput: + if image is None: + raise RuntimeError("Image is required") + tripo_file = TripoFileReference( + root=TripoUrlReference( + url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], + type="jpeg", + ) + ) + common = _build_p1_request_kwargs( + output_mode=output_mode, + face_limit=face_limit, + model_seed=model_seed, + auto_size=auto_size, + export_uv=export_uv, + compress_geometry=compress_geometry, + ) + request = TripoP1ImageToModelRequest( + file=tripo_file, + enable_image_autofix=enable_image_autofix, + **common, + ) + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=request, + ) + return await poll_until_finished(cls, response, average_duration=60) + + +class TripoP1MultiviewToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoP1MultiviewToModelNode", + display_name="Tripo P1: Multiview to Model", + category="3d/partner/Tripo", + description="Tripo P1 multiview-to-3D from 2-4 reference images in [front, left, back, right] order. " + "Front is required; any combination of the other three may be omitted.", + inputs=[ + IO.Image.Input("image", tooltip="Front view (0°). Required."), + IO.Image.Input( + "image_left", + optional=True, + tooltip="Left view (90°), i.e. the subject's left side.", + ), + IO.Image.Input("image_back", optional=True, tooltip="Back view (180°)."), + IO.Image.Input( + "image_right", + optional=True, + tooltip="Right view (270°), i.e. the subject's right side.", + ), + _build_p1_output_mode(include_image_alignment=True), + *_p1_common_inputs(), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]), + expr=_p1_price_expr(geometry_credits=40, textured_credits=50, detailed_credits=60), + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + output_mode: dict, + image_left: Input.Image | None = None, + image_back: Input.Image | None = None, + image_right: Input.Image | None = None, + face_limit: int = -1, + model_seed: int | None = None, + auto_size: bool = False, + export_uv: bool = True, + compress_geometry: bool = False, + ) -> IO.NodeOutput: + views = [image, image_left, image_back, image_right] + if sum(1 for v in views if v is not None) < 2: + raise RuntimeError("Tripo P1 multiview requires at least 2 images (front plus one of left/back/right).") + + files: list[TripoFileReference] = [] + for view in views: + if view is None: + files.append(TripoFileReference(root=TripoFileEmptyReference())) + continue + url = (await upload_images_to_comfyapi(cls, view, max_images=1))[0] + files.append(TripoFileReference(root=TripoUrlReference(url=url, type="jpeg"))) + + common = _build_p1_request_kwargs( + output_mode=output_mode, + face_limit=face_limit, + model_seed=model_seed, + auto_size=auto_size, + export_uv=export_uv, + compress_geometry=compress_geometry, + ) + request = TripoP1MultiviewToModelRequest(files=files, **common) + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=request, + ) + return await poll_until_finished(cls, response, average_duration=80) + + class TripoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -855,6 +1289,9 @@ class TripoExtension(ComfyExtension): TripoTextToModelNode, TripoImageToModelNode, TripoMultiviewToModelNode, + TripoP1TextToModelNode, + TripoP1ImageToModelNode, + TripoP1MultiviewToModelNode, TripoTextureNode, TripoRefineNode, TripoRigNode, From 6e1ef2311ba73e68330e4041b34cdfd9e8fb6aa2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 29 May 2026 16:26:46 -0700 Subject: [PATCH 06/32] Remove useless code. (#14178) --- comfy/ldm/qwen_image/model.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 0862f72f7..3462d8108 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -51,15 +51,6 @@ class FeedForward(nn.Module): return hidden_states -def apply_rotary_emb(x, freqs_cis): - if x.shape[1] == 0: - return x - - t_ = x.reshape(*x.shape[:-1], -1, 1, 2) - t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x.shape) - - class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): super().__init__() From 0b04660ba329ce643f29cabf4905aa9b5a71c6f8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 29 May 2026 22:47:10 -0700 Subject: [PATCH 07/32] Speed up anima a bit on nvidia. (#14181) --- comfy/ldm/cosmos/predict2.py | 12 +----------- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 2268bff38..30a36ad49 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -15,15 +15,6 @@ import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention import comfy.ldm.common_dit -def apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, -) -> torch.Tensor: - t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() - t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] - t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) - return t_out - # ---------------------- Feed Forward Network ----------------------- class GPT2FeedForward(nn.Module): @@ -173,8 +164,7 @@ class Attention(nn.Module): k = self.k_norm(k) v = self.v_norm(v) if self.is_selfattn and rope_emb is not None: # only apply to self-attention! - q = apply_rotary_pos_emb(q, rope_emb) - k = apply_rotary_pos_emb(k, rope_emb) + q, k = comfy.quant_ops.ck.apply_rope_split_half(q, k, rope_emb) return q, k, v q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) diff --git a/requirements.txt b/requirements.txt index 0617667e1..7ae16eb5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy>=2.0.0 filelock av>=16.0.0 -comfy-kitchen==0.2.9 +comfy-kitchen==0.2.10 comfy-aimdo==0.4.5 requests simpleeval>=1.0.0 From bb560036b9631d0049d51abd388be262bb57a152 Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Sat, 30 May 2026 09:39:26 -0400 Subject: [PATCH 08/32] feat(io): add File3DPLY / File3DSPLAT / File3DSPZ / File3DKSPLAT types (#14185) --- comfy_api/latest/_io.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 19d8176b0..e03bafcde 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -727,6 +727,30 @@ class File3DUSDZ(ComfyTypeIO): Type = File3D +@comfytype(io_type="FILE_3D_PLY") +class File3DPLY(ComfyTypeIO): + """PLY format 3D file - point cloud or Gaussian splat.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_SPLAT") +class File3DSPLAT(ComfyTypeIO): + """SPLAT format 3D file - 3D Gaussian splat.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_SPZ") +class File3DSPZ(ComfyTypeIO): + """SPZ format 3D file - compressed 3D Gaussian splat.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_KSPLAT") +class File3DKSPLAT(ComfyTypeIO): + """KSPLAT format 3D file - 3D Gaussian splat.""" + Type = File3D + + @comfytype(io_type="HOOKS") class Hooks(ComfyTypeIO): if TYPE_CHECKING: @@ -2303,6 +2327,10 @@ __all__ = [ "File3DOBJ", "File3DSTL", "File3DUSDZ", + "File3DPLY", + "File3DSPLAT", + "File3DSPZ", + "File3DKSPLAT", "Hooks", "HookKeyframes", "TimestepsRange", From e154da83b135c0f2ae6d73a2c177566016ea2158 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 31 May 2026 05:20:04 +1000 Subject: [PATCH 09/32] Threaded Loader performance fixes / improvements (+ Aimdo 0.4.6) (#14116) * memory_management: Add direct to read GPU mode Make destination optional (or make it optionally GPU) and use aimdo to file_read direct to GPU. * ops: Remove stream pin buffers and use aimdo reads This consumed too much RAM and its better to just take the hit on the CPU syncing back the stream on a short ring buffer. Aimdo implements this so just rip the stream pin buffer from comfy. * model_management: all active pin registration movement Its better to just let the active model load past the pin limit as pins and let the pins move around. The saves the HDD and SATA people disk traffic while only costing a few GPU syncs. * utils: use aimdo file handle This opens on windows with more favourable flags * mp: only count the model proper for loaded_ram and vram Exclude live loras from the numbers to avoid the case where the reported loaded memory exceeds the size of the model. This causes me confusion in the Kijai visualizer when it looked fully loaded but was hitting disk due to this accounding disrepency. * utils: add bit reverse utility useful for max scattering something ordered. * pinned_memory: Implement offload balancing Use a max scatter alogorithm to prioritize pins of the same size such that when doing a little bit of offloading it gets scattered, allowing the prefetcher to more evenly swollow the offload. * comfy-aimdo 0.4.7 Aimdo 0.4.7 implement VRAM buffer exhaustion predection to avoid early speculative load of weights that definately wont fix once the inference gets further in. * model-prefetch: consolidate pin ensures on the sync point This could happen mid prefetch block, cause a sync of the entire block and lose overlap. Get ahead of the problem with a free down at the natural compute stream sync point. * mm: Put a 2GB min on the pin ceiling This is reasonably bad if it starts causing swap pressure, moreso than during normal ram-cache proceedings. Clamp it. * add --fast-disk --- comfy/cli_args.py | 1 + comfy/memory_management.py | 40 ++++++++++++++------- comfy/model_management.py | 73 +++++++++++++++----------------------- comfy/model_patcher.py | 20 +++++------ comfy/model_prefetch.py | 11 ++++++ comfy/ops.py | 66 ++++++---------------------------- comfy/pinned_memory.py | 59 +++++++++++++++++++++++++----- comfy/utils.py | 9 ++++- requirements.txt | 2 +- 9 files changed, 149 insertions(+), 132 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 9bda414d1..a4cabcc65 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -149,6 +149,7 @@ parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=Non parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.") +parser.add_argument("--fast-disk", action="store_true", help="Prefer disk-backed dynamic loading and offload over unpinned RAM. Can be faster for users with fast NVME disks.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 962addb27..e032b7dcd 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -4,6 +4,7 @@ import dataclasses import torch from typing import NamedTuple +import comfy_aimdo.host_buffer from comfy.quant_ops import QuantizedTensor @@ -17,21 +18,18 @@ class TensorFileSlice(NamedTuple): def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None): if isinstance(tensor, QuantizedTensor): - if not isinstance(destination, QuantizedTensor): - return False - if tensor._layout_cls != destination._layout_cls: - return False - - if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream, + if not read_tensor_file_slice_into(tensor._qdata, + destination._qdata if destination is not None else None, stream=stream, destination2=(destination2._qdata if destination2 is not None else None)): return False - dst_orig_dtype = destination._params.orig_dtype - destination._params.copy_from(tensor._params, non_blocking=False) - destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) + if destination is not None: + dst_orig_dtype = destination._params.orig_dtype + destination._params.copy_from(tensor._params, non_blocking=False) + destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) if destination2 is not None: dst_orig_dtype = destination2._params.orig_dtype - destination2._params.copy_from(destination._params, non_blocking=True) + destination2._params.copy_from(destination._params if destination is not None else tensor._params, non_blocking=True) destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype) return True @@ -39,10 +37,15 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N if info is None: return False + if destination is not None and destination.device.type != "cpu" and destination2 is None: + destination2 = destination + destination = None + file_obj = info.file_ref - if (destination.device.type != "cpu" - or file_obj is None - or destination.numel() * destination.element_size() < info.size + if (file_obj is None + or (destination is None and destination2 is None) + or (destination is not None and (destination.device.type != "cpu" or destination.numel() * destination.element_size() < info.size)) + or (destination2 is not None and (destination2.device.type == "cpu" or destination2.numel() * destination2.element_size() < info.size)) or tensor.numel() * tensor.element_size() != info.size or tensor.storage_offset() != 0 or not tensor.is_contiguous()): @@ -51,6 +54,14 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N if info.size == 0: return True + if destination is None: + stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 + comfy_aimdo.host_buffer.read_file_to_device(file_obj, info.offset, info.size, + stream_ptr, destination2.data_ptr(), + destination2.device.index, + mark_cold=False) + return True + hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None) if hostbuf is not None: stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 @@ -63,6 +74,9 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N device=None if destination2 is None else destination2.device.index) return True + if not hasattr(file_obj, "seek") or not hasattr(file_obj, "readinto"): + return False + buf_type = ctypes.c_ubyte * info.size view = memoryview(buf_type.from_address(destination.data_ptr())) diff --git a/comfy/model_management.py b/comfy/model_management.py index b01c4d7fa..c264efc2d 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -641,14 +641,17 @@ def free_pins(size, evict_active=False): return freed_total def ensure_pin_budget(size, evict_active=False): - shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available + if args.fast_disk: + shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY + else: + shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available if shortfall <= 0: return True to_free = shortfall + PIN_PRESSURE_HYSTERESIS return free_pins(to_free, evict_active=evict_active) >= shortfall -def ensure_pin_registerable(size, evict_active=False): +def ensure_pin_registerable(size, evict_active=True): shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY if MAX_PINNED_MEMORY <= 0: return False @@ -658,10 +661,17 @@ def ensure_pin_registerable(size, evict_active=False): shortfall += REGISTERABLE_PIN_HYSTERESIS for loaded_model in reversed(current_loaded_models): model = loaded_model.model - if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): + if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]: shortfall -= model.unregister_inactive_pins(shortfall) if shortfall <= 0: return True + if evict_active: + for loaded_model in current_loaded_models: + model = loaded_model.model + if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]: + shortfall -= model.unregister_inactive_pins(shortfall) + if shortfall <= 0: + return True return shortfall <= REGISTERABLE_PIN_HYSTERESIS class LoadedModel: @@ -1283,7 +1293,6 @@ STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) STREAM_AIMDO_CAST_BUFFERS = {} LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) -STREAM_PIN_BUFFERS = {} DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 @@ -1326,42 +1335,13 @@ def get_aimdo_cast_buffer(offload_stream, device): STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer return cast_buffer -def get_pin_buffer(offload_stream): - pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None) - if pin_buffer is None: - pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False) - STREAM_PIN_BUFFERS[offload_stream] = pin_buffer - elif offload_stream is not None: - event = getattr(pin_buffer, "_comfy_event", None) - if event is not None: - event.synchronize() - delattr(pin_buffer, "_comfy_event") - return pin_buffer - -def resize_pin_buffer(pin_buffer, size): - global TOTAL_PINNED_MEMORY - old_size = pin_buffer.size - if size <= old_size: - return True - growth = size - old_size - comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) - ensure_pin_budget(growth, evict_active=True) - ensure_pin_registerable(growth, evict_active=True) - try: - pin_buffer.extend(size=size, reallocate=True) - except RuntimeError: - return False - TOTAL_PINNED_MEMORY += pin_buffer.size - old_size - return True - def reset_cast_buffers(): - global TOTAL_PINNED_MEMORY global LARGEST_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) - for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS): + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): if offload_stream is not None: offload_stream.synchronize() synchronize() @@ -1370,20 +1350,24 @@ def reset_cast_buffers(): mmap_obj.bounce() DIRTY_MMAPS.clear() - for pin_buffer in STREAM_PIN_BUFFERS.values(): - TOTAL_PINNED_MEMORY -= pin_buffer.size - TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY) - for loaded_model in current_loaded_models: model = loaded_model.model if model is not None and model.is_dynamic(): - model.model.dynamic_pins[model.load_device]["active"] = False + pin_state = model.model.dynamic_pins[model.load_device] + + if pin_state["active"]: + *_, buckets = pin_state["weights"] + for size, bucket in list(buckets.items()): + bucket[:] = [ entry for entry in bucket if entry[-1] is not None ] + if not bucket: + del buckets[size] + + pin_state["active"] = False model.partially_unload_ram(1e30, subsets=[ "patches" ]) - model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0]) + model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {}) STREAM_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear() - STREAM_PIN_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): @@ -1436,7 +1420,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None): if hasattr(wf_context, "as_context"): wf_context = wf_context.as_context(stream) - dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors) dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None with wf_context: for tensor in tensors: @@ -1448,9 +1432,10 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None): continue storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() mark_mmap_dirty(storage) - dest_view.copy_(tensor, non_blocking=non_blocking) + if dest_view is not None: + dest_view.copy_(tensor, non_blocking=non_blocking) if dest2_view is not None: - dest2_view.copy_(dest_view, non_blocking=non_blocking) + dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 00a15fa63..b716a69e2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1721,8 +1721,8 @@ class ModelPatcherDynamic(ModelPatcher): """ if device not in self.model.dynamic_pins: self.model.dynamic_pins[device] = { - "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), - "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), + "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}), + "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}), "hostbufs_initialized": False, "failed": False, "active": False, @@ -1799,8 +1799,8 @@ class ModelPatcherDynamic(ModelPatcher): pin_state = self.model.dynamic_pins[self.load_device] if not pin_state["hostbufs_initialized"]: hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size()) - pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0]) - pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0]) + pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {}) + pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {}) pin_state["hostbufs_initialized"] = True pin_state["failed"] = False pin_state["active"] = True @@ -1942,18 +1942,16 @@ class ModelPatcherDynamic(ModelPatcher): return freed def loaded_ram_size(self): - return (self.model.dynamic_pins[self.load_device]["weights"][0].size + - self.model.dynamic_pins[self.load_device]["patches"][0].size) + return (self.model.dynamic_pins[self.load_device]["weights"][0].size) def pinned_memory_size(self): - return (self.model.dynamic_pins[self.load_device]["weights"][3][0] + - self.model.dynamic_pins[self.load_device]["patches"][3][0]) + return (self.model.dynamic_pins[self.load_device]["weights"][3][0]) def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]): freed = 0 pin_state = self.model.dynamic_pins[self.load_device] for subset in subsets: - hostbuf, stack, stack_split, pinned_size = pin_state[subset] + hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset] split = stack_split[0] while split >= 0: module, offset = stack[split] @@ -1978,10 +1976,12 @@ class ModelPatcherDynamic(ModelPatcher): freed = 0 pin_state = self.model.dynamic_pins[self.load_device] for subset in subsets: - hostbuf, stack, stack_split, pinned_size = pin_state[subset] + hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset] while len(stack) > 0: module, offset = stack.pop() size = module._pin.numel() * module._pin.element_size() + module._pin_balancer_entry[-1] = None + del module._pin_balancer_entry del module._pin hostbuf.truncate(offset, do_unregister=module._pin_registered) stack_split[0] = min(stack_split[0], len(stack) - 1) diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py index 72e11dec6..aa6d22d77 100644 --- a/comfy/model_prefetch.py +++ b/comfy/model_prefetch.py @@ -1,4 +1,5 @@ import comfy_aimdo.model_vbar +import comfy.memory_management import comfy.model_management import comfy.ops @@ -50,7 +51,17 @@ def prefetch_queue_pop(queue, device, module): if hasattr(s, "_v"): comfy_modules.append(s) + registerable_size = 0 + for s in comfy_modules: + registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias]) + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + registerable_size += lowvram_fn.memory_required() + offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) + if not comfy.model_management.args.fast_disk: + comfy.model_management.ensure_pin_registerable(registerable_size) comfy.model_management.sync_stream(device, offload_stream) queue[0] = (offload_stream, (prefetch, comfy_modules)) diff --git a/comfy/ops.py b/comfy/ops.py index 56445be8d..119177c37 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -76,8 +76,6 @@ except: cast_to = comfy.model_management.cast_to #TODO: remove once no more references -STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024 - def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -94,9 +92,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin offload_stream = None cast_buffer = None cast_buffer_offset = 0 - stream_pin_hostbuf = None - stream_pin_offset = 0 - stream_pin_queue = [] def ensure_offload_stream(module, required_size, check_largest): nonlocal offload_stream @@ -130,22 +125,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin cast_buffer_offset += buffer_size return buffer - def get_stream_pin_buffer_offset(buffer_size): - nonlocal stream_pin_hostbuf - nonlocal stream_pin_offset - - if buffer_size == 0 or offload_stream is None: - return None - - if stream_pin_hostbuf is None: - stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream) - if stream_pin_hostbuf is None: - return None - - offset = stream_pin_offset - stream_pin_offset += buffer_size - return offset - for s in comfy_modules: signature = comfy_aimdo.model_vbar.vbar_fault(s._v) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) @@ -184,12 +163,18 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if xfer_dest is None: xfer_dest = get_cast_buffer(dest_size) - def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream): + def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream, xfer_dest2=None): if xfer_source is not None: if getattr(xfer_source, "is_lowvram_patch", False): - xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) - else: - comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream) + if xfer_dest is not None: + xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) + xfer_source = [ xfer_dest ] + xfer_dest = xfer_dest2 + xfer_dest2 = None + elif xfer_dest2 is not None: + xfer_source.prepare(xfer_dest2, stream, copy=True, commit=False) + return + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream, r2=xfer_dest2) def handle_pin(m, pin, source, dest, subset="weights", size=None): if pin is not None: @@ -198,19 +183,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if signature is None: comfy.pinned_memory.pin_memory(m, subset=subset, size=size) pin = comfy.pinned_memory.get_pin(m, subset=subset) - if pin is not None: - if isinstance(source, list): - comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest) - else: - cast_maybe_lowvram_patch(source, pin, None) - cast_maybe_lowvram_patch([ pin ], dest, offload_stream) - return - if pin is None: - pin_offset = get_stream_pin_buffer_offset(size) - if pin_offset is not None: - stream_pin_queue.append((source, pin_offset, size, dest)) - return - cast_maybe_lowvram_patch(source, dest, offload_stream) + cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest) handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size) @@ -232,23 +205,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin prefetch["needs_cast"] = needs_cast s._prefetch = prefetch - if stream_pin_offset > 0: - if stream_pin_hostbuf.size < stream_pin_offset: - if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM): - for xfer_source, _, _, xfer_dest in stream_pin_queue: - cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream) - return offload_stream - stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf) - stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf - for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue: - pin = stream_pin_tensor[pin_offset:pin_offset + pin_size] - if isinstance(xfer_source, list): - comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest) - else: - cast_maybe_lowvram_patch(xfer_source, pin, None) - comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream) - stream_pin_hostbuf._comfy_event = offload_stream.record_event() - return offload_stream diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 0e8f573ba..ffe12e0dc 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -1,17 +1,55 @@ +import bisect + import comfy.model_management import comfy.memory_management +import comfy.utils import comfy_aimdo.host_buffer import comfy_aimdo.torch import torch from comfy.cli_args import args +def _add_to_bucket(module, buckets, size, priority): + bucket = buckets.setdefault(size, []) + entry = [-priority, 0, module] + entry[1] = id(entry) + bisect.insort(bucket, entry) + module._pin_balancer_entry = entry + +def _steal_pin(module, stack, buckets, size, priority): + bucket = buckets.get(size) + if bucket is None: + return False + + while bucket and bucket[-1][-1] is None: + bucket.pop() + if not bucket: + del buckets[size] + return False + + if priority <= -bucket[-1][0]: + return False + + *_, victim = bucket.pop() + module._pin = victim._pin + module._pin_registered = victim._pin_registered + module._pin_stack_index = victim._pin_stack_index + stack[module._pin_stack_index] = (module, stack[module._pin_stack_index][1]) + + victim._pin_registered = False + del victim._pin + del victim._pin_stack_index + del victim._pin_balancer_entry + + _add_to_bucket(module, buckets, size, priority) + return True + def get_pin(module, subset="weights"): pin = getattr(module, "_pin", None) if pin is None or module._pin_registered or args.disable_pinned_memory: return pin - _, _, stack_split, pinned_size = module._pin_state[subset] + _, _, stack_split, pinned_size, *_ = module._pin_state[subset] size = pin.nbytes comfy.model_management.ensure_pin_registerable(size) @@ -31,26 +69,30 @@ def pin_memory(module, subset="weights", size=None): return pin = get_pin(module, subset) - if pin is not None or pin_state["failed"]: + if pin is not None: return - hostbuf, stack, stack_split, pinned_size = pin_state[subset] + hostbuf, stack, stack_split, pinned_size, counter, buckets = pin_state[subset] if size is None: size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) offset = hostbuf.size - registerable_size = size + max(0, hostbuf.size - pinned_size[0]) + registerable_size = size + priority = getattr(module, "_pin_balancer_priority", None) + + if priority is None: + priority = comfy.utils.bit_reverse_range(counter[0], 16) + counter[0] += 1 + module._pin_balancer_priority = priority comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) if (not comfy.model_management.ensure_pin_budget(size) or not comfy.model_management.ensure_pin_registerable(registerable_size)): - pin_state["failed"] = True - return False + return _steal_pin(module, stack, buckets, size, priority) try: hostbuf.extend(size=size) except RuntimeError: - pin_state["failed"] = True - return False + return _steal_pin(module, stack, buckets, size, priority) module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] module._pin.untyped_storage()._comfy_hostbuf = hostbuf @@ -60,4 +102,5 @@ def pin_memory(module, subset="weights", size=None): stack_split[0] = max(stack_split[0], module._pin_stack_index) comfy.model_management.TOTAL_PINNED_MEMORY += size pinned_size[0] += size + _add_to_bucket(module, buckets, size, priority) return True diff --git a/comfy/utils.py b/comfy/utils.py index 49ae12b06..09d783fff 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -85,9 +85,9 @@ _TYPES = { def load_safetensors(ckpt): import comfy_aimdo.model_mmap - f = open(ckpt, "rb", buffering=0) file_lock = threading.Lock() model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt) + f = model_mmap.get_file_handle() file_size = os.path.getsize(ckpt) mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get())) @@ -1452,3 +1452,10 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res + +def bit_reverse_range(index, bits): + result = 0 + for _ in range(bits): + result = (result << 1) | (index & 1) + index >>= 1 + return result diff --git a/requirements.txt b/requirements.txt index 7ae16eb5b..b2945c3cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=16.0.0 comfy-kitchen==0.2.10 -comfy-aimdo==0.4.5 +comfy-aimdo==0.4.7 requests simpleeval>=1.0.0 blake3 From f7297bc5a9a5c9603a2926791c090bba4962d1cb Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 31 May 2026 05:20:33 +1000 Subject: [PATCH 10/32] Revert deprecation of non-dynamic smart memory (CORE-152 (revert)) (#14183) * mm: re-instantate smart memory for VRAM * mm: restore non-dynamic smart memory By popular demand. We aren't quite ready for the deprecation as non dynamic enabled GPUs and some high-vram custom model loader setups prefer the old full hands on. --- comfy/model_management.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c264efc2d..8fb1d7fbc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -813,9 +813,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins for x in can_unload_sorted: i = x[-1] memory_to_free = 1e32 - if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None): + if not DISABLE_SMART_MEMORY or device is None: memory_to_free = 0 if device is None else memory_required - get_free_memory(device) - if for_dynamic: + if current_loaded_models[i].model.is_dynamic() and for_dynamic: #don't actually unload dynamic models for the sake of other dynamic models #as that works on-demand. memory_required -= current_loaded_models[i].model.loaded_size() @@ -827,6 +827,10 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) + if not for_dynamic and pins_required > 0: + ensure_pin_budget(pins_required) + ensure_pin_registerable(pins_required) + if len(unloaded_model) > 0: soft_empty_cache() elif device is not None: @@ -889,15 +893,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model_finalizer.detach() total_memory_required = {} + total_pins_required = {} for loaded_model in models_to_load: device = loaded_model.device total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) + if not loaded_model.model.is_dynamic(): + total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory() for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.1 + extra_mem, device, - for_dynamic=free_for_dynamic) + for_dynamic=free_for_dynamic, + pins_required=total_pins_required.get(device, 0)) for device in total_memory_required: if device != torch.device("cpu"): From 08e93a31a3120172ee31755ba70dda4f1957d8cc Mon Sep 17 00:00:00 2001 From: Terry Jia Date: Sat, 30 May 2026 17:57:36 -0400 Subject: [PATCH 11/32] feat: add Preview3DAdvanced node (#14175) Co-authored-by: Alexis Rolland --- comfy_api/latest/_ui.py | 11 +++++++ comfy_extras/nodes_load_3d.py | 59 +++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index e238cdf3c..6592f6b1d 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -452,6 +452,16 @@ class PreviewUI3D(_UIOutput): return {"result": [self.model_file, self.camera_info, self.bg_image_path]} +class PreviewUI3DAdvanced(_UIOutput): + def __init__(self, model_file, camera_info, model_3d_info): + self.model_file = model_file + self.camera_info = camera_info + self.model_3d_info = model_3d_info + + def as_dict(self): + return {"result": [self.model_file, self.camera_info, self.model_3d_info]} + + class PreviewText(_UIOutput): def __init__(self, value: str, **kwargs): self.value = value @@ -471,5 +481,6 @@ __all__ = [ "PreviewAudio", "PreviewVideo", "PreviewUI3D", + "PreviewUI3DAdvanced", "PreviewText", ] diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 6f05f050e..b339dc4ff 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -124,12 +124,71 @@ class Preview3D(IO.ComfyNode): process = execute # TODO: remove +class Preview3DAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Preview3DAdvanced", + display_name="Preview 3D (Advanced)", + search_aliases=["preview 3d", "3d viewer", "view mesh", "frame 3d", "3d camera output"], + category="3d", + is_experimental=True, + is_output_node=True, + inputs=[ + IO.MultiType.Input( + "model_file", + types=[ + IO.File3DGLB, + IO.File3DGLTF, + IO.File3DFBX, + IO.File3DOBJ, + IO.File3DSTL, + IO.File3DUSDZ, + IO.File3DAny, + ], + tooltip="3D model file from an upstream 3D node.", + ), + IO.Load3D.Input("image"), + IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DAny.Output(display_name="model_file"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def execute(cls, model_file: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput: + filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_file.format}" + model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + + camera_info_input = kwargs.get("camera_info", None) + camera_info = camera_info_input if camera_info_input is not None else image['camera_info'] + model_3d_info_input = kwargs.get("model_3d_info", None) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', []) + return IO.NodeOutput( + model_file, + camera_info, + model_3d_info, + width, + height, + ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), + ) + + class Load3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ Load3D, Preview3D, + Preview3DAdvanced, ] From ea73d3b2ea581cdd87de5ef5b684fef7e3e69ffa Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Sun, 31 May 2026 07:49:59 +0800 Subject: [PATCH 12/32] chore: update embedded docs to v0.5.2 (#14193) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b2945c3cf..14bba1437 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.44.19 comfyui-workflow-templates==0.9.91 -comfyui-embedded-docs==0.5.1 +comfyui-embedded-docs==0.5.2 torch torchsde torchvision From 81aa5a38b25a59dadb8b2765e08e277e044351a6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 30 May 2026 17:53:37 -0700 Subject: [PATCH 13/32] Speed up ernie model by a bit on nvidia and use higher quality rope. (#14192) --- comfy/ldm/cosmos/predict2.py | 1 + comfy/ldm/ernie/model.py | 27 +++++++++++++-------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 30a36ad49..671fe834d 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -14,6 +14,7 @@ from torchvision import transforms import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention import comfy.ldm.common_dit +import comfy.quant_ops # ---------------------- Feed Forward Network ----------------------- diff --git a/comfy/ldm/ernie/model.py b/comfy/ldm/ernie/model.py index eba661aec..f158ca1d2 100644 --- a/comfy/ldm/ernie/model.py +++ b/comfy/ldm/ernie/model.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import comfy.quant_ops def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0 @@ -19,15 +20,6 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: out = torch.stack([torch.cos(out), torch.sin(out)], dim=0) return out.to(dtype=torch.float32, device=pos.device) -def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - rot_dim = freqs_cis.shape[-1] - x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:] - cos_ = freqs_cis[0] - sin_ = freqs_cis[1] - x1, x2 = x.chunk(2, dim=-1) - x_rotated = torch.cat((-x2, x1), dim=-1) - return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1) - class ErnieImageEmbedND3(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: tuple): super().__init__() @@ -37,8 +29,16 @@ class ErnieImageEmbedND3(nn.Module): def forward(self, ids: torch.Tensor) -> torch.Tensor: emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1) - emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2] - return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim] + cos_ = emb[0] + sin_ = emb[1] + N = cos_.shape[-1] + half = N // 2 + cos_top = cos_[..., :half].repeat_interleave(2, dim=-1) + sin_top = sin_[..., :half].repeat_interleave(2, dim=-1) + cos_bot = cos_[..., half:].repeat_interleave(2, dim=-1) + sin_bot = sin_[..., half:].repeat_interleave(2, dim=-1) + rot = torch.stack([cos_top, -sin_top, sin_bot, cos_bot], dim=-1) + return rot.reshape(*rot.shape[:-1], 2, 2).unsqueeze(2) class ErnieImagePatchEmbedDynamic(nn.Module): def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None): @@ -115,8 +115,7 @@ class ErnieImageAttention(nn.Module): key = self.norm_k(key) if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) + query, key = comfy.quant_ops.ck.apply_rope_split_half(query, key, image_rotary_emb) q_flat = query.reshape(B, S, -1) k_flat = key.reshape(B, S, -1) @@ -274,7 +273,7 @@ class ErnieImageModel(nn.Module): image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1) - rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype) + rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) del image_ids, text_ids sample = self.time_proj(timesteps).to(dtype) From cd45f42a83c75403e3b69f85789a2d3af43ced66 Mon Sep 17 00:00:00 2001 From: savvadesogle Date: Sun, 31 May 2026 04:18:42 +0300 Subject: [PATCH 14/32] fix(multigpu): replace hardcoded torch.cuda.set_device with device-agnostic set_torch_device (#14191) --- comfy/model_management.py | 7 +++++++ comfy/multigpu.py | 4 ++-- comfy/samplers.py | 5 +---- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 8fb1d7fbc..dfd58bf1b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1716,6 +1716,13 @@ def is_device_xpu(device): def is_device_cuda(device): return is_device_type(device, 'cuda') +def set_torch_device(device): + """Set the current device for the given torch device. Supports CUDA and XPU.""" + if is_device_cuda(device): + torch.cuda.set_device(device) + elif is_device_xpu(device): + torch.xpu.set_device(device) + def is_directml_enabled(): global directml_enabled if directml_enabled: diff --git a/comfy/multigpu.py b/comfy/multigpu.py index e7f5b3d6f..bb9d334d3 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -17,7 +17,7 @@ class MultiGPUThreadPool: """Persistent thread pool for multi-GPU work distribution. Maintains one worker thread per extra GPU device. Each thread calls - torch.cuda.set_device() once at startup so that compiled kernel caches + set_torch_device() once at startup so that compiled kernel caches (inductor/triton) stay warm across diffusion steps. """ @@ -37,7 +37,7 @@ class MultiGPUThreadPool: def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue): try: - torch.cuda.set_device(device) + comfy.model_management.set_torch_device(device) except Exception as e: logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}") while True: diff --git a/comfy/samplers.py b/comfy/samplers.py index e31277f7b..25c5a855f 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): try: - # TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once - # we extend multigpu QA beyond CUDA. Unconditional call crashes on - # XPU/NPU/MPS/CPU/DirectML backends. - torch.cuda.set_device(device) + comfy.model_management.set_torch_device(device) model_current: BaseModel = model_options["multigpu_clones"][device].model # run every hooked_to_run separately with torch.no_grad(): From c37d2a0dacaa256a2fb1812ae026e09dd493661e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sun, 31 May 2026 21:47:29 +0300 Subject: [PATCH 15/32] feat: Add gaussian splat nodes (#14190) --- comfy_api/latest/__init__.py | 3 +- comfy_api/latest/_io.py | 7 +- comfy_api/latest/_util/__init__.py | 3 +- comfy_api/latest/_util/geometry_types.py | 23 +- comfy_extras/nodes_gaussian_splat.py | 1663 ++++++++++++++++++++++ comfy_extras/nodes_save_3d.py | 22 +- nodes.py | 1 + 7 files changed, 1714 insertions(+), 8 deletions(-) create mode 100644 comfy_extras/nodes_gaussian_splat.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index e0a585b10..294ad425e 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input_impl import VideoFromFile, VideoFromComponents -from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D from . import _io_public as io from . import _ui_public as ui from comfy_execution.utils import get_executing_context @@ -143,6 +143,7 @@ class Types: VideoComponents = VideoComponents MESH = MESH VOXEL = VOXEL + SPLAT = SPLAT File3D = File3D diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index e03bafcde..a3aa508ce 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL, SVG as _SVG, File3D +from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D class FolderType(str, Enum): @@ -684,6 +684,10 @@ class Voxel(ComfyTypeIO): class Mesh(ComfyTypeIO): Type = MESH +@comfytype(io_type="SPLAT") +class Splat(ComfyTypeIO): + Type = SPLAT + @comfytype(io_type="FILE_3D") class File3DAny(ComfyTypeIO): @@ -2320,6 +2324,7 @@ __all__ = [ "LossMap", "Voxel", "Mesh", + "Splat", "File3DAny", "File3DGLB", "File3DGLTF", diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index 115baf392..b27f5a97e 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,5 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents -from .geometry_types import VOXEL, MESH, File3D +from .geometry_types import VOXEL, MESH, SPLAT, File3D from .image_types import SVG __all__ = [ @@ -9,6 +9,7 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "SPLAT", "File3D", "SVG", ] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index cdde60b10..84a18d69a 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -11,13 +11,32 @@ class VOXEL: self.data = data +class SPLAT: + """A batch of 3D Gaussian splats in render-ready (activated, world-space) form. + + Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the + real per-item lengths (None when rows are uniform and no slicing is needed). SH coefficients are + stored as (B, N, K, 3) with K = (sh_degree + 1)**2; the DC (diffuse) term is sh[..., 0, :]. + """ + + def __init__(self, positions: torch.Tensor, scales: torch.Tensor, rotations: torch.Tensor, + opacities: torch.Tensor, sh: torch.Tensor, counts: torch.Tensor | None = None): + self.positions = positions # (B, N, 3) world-space centers + self.scales = scales # (B, N, 3) linear (positive) per-axis std + self.rotations = rotations # (B, N, 4) quaternion wxyz (normalized) + self.opacities = opacities # (B, N, 1) in [0, 1] + self.sh = sh # (B, N, K, 3) spherical-harmonic color coefficients + self.counts = counts # (B,) real lengths, or None + + class MESH: def __init__(self, vertices: torch.Tensor, faces: torch.Tensor, uvs: torch.Tensor | None = None, vertex_colors: torch.Tensor | None = None, texture: torch.Tensor | None = None, vertex_counts: torch.Tensor | None = None, - face_counts: torch.Tensor | None = None): + face_counts: torch.Tensor | None = None, + unlit: bool = False): assert (vertex_counts is None) == (face_counts is None), \ "vertex_counts and face_counts must be provided together (both or neither)" @@ -30,6 +49,8 @@ class MESH: # these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed. self.vertex_counts = vertex_counts self.face_counts = face_counts + # Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes. + self.unlit = unlit class File3D: diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py new file mode 100644 index 000000000..7fb878b8b --- /dev/null +++ b/comfy_extras/nodes_gaussian_splat.py @@ -0,0 +1,1663 @@ +# Generic utility nodes for the SPLAT type (3D gaussian splats) + +import gzip +import logging +import math +import struct +from io import BytesIO + +import numpy as np +import torch +from typing_extensions import override +from scipy.ndimage import map_coordinates, minimum as _ndi_minimum, maximum as _ndi_maximum +from scipy.sparse import coo_matrix +from scipy.sparse.csgraph import connected_components + +import comfy.model_management +import comfy.utils +from comfy_api.latest import ComfyExtension, IO, Types +from comfy_extras.nodes_save_3d import pack_variable_mesh_batch +from server import PromptServer + +_C0 = 0.28209479177387814 # SH band-0 constant: DC coefficient -> base RGB + + +def _srgb_to_linear(c): + return torch.where(c <= 0.04045, c / 12.92, ((c.clamp_min(0) + 0.055) / 1.055) ** 2.4) + + +def _linear_to_srgb(c): + return torch.where(c <= 0.0031308, c * 12.92, 1.055 * c.clamp_min(0) ** (1 / 2.4) - 0.055) + + +def _real_len(g: Types.SPLAT, i: int) -> int: + # Real splat count of batch item i (honors variable-length `counts`). + return int(g.counts[i].item()) if g.counts is not None else g.positions.shape[1] + + +def _hex_to_rgb(h: str) -> tuple[float, float, float]: + # "#RRGGBB" -> (r,g,b) in [0,1]; falls back to black. + h = h.lstrip("#") + if len(h) != 6: + return (0.0, 0.0, 0.0) + return tuple(int(h[i:i + 2], 16) / 255.0 for i in (0, 2, 4)) + + +def _quantile(x, q): + # torch.quantile errors above 2**24 elements; stride-subsample large inputs for the estimate. + lim = 1 << 24 + if x.numel() > lim: + x = x[:: x.numel() // lim + 1] + return torch.quantile(x, q) + + +def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes: + """Serialize render-ready gaussian tensors as a binary 3DGS .ply. + + positions (N,3) world; scales (N,3) linear; rotations (N,4) quat wxyz; opacities (N,1) in [0,1]; + sh (N,K,3) SH coefficients. Activated values are inverted to the standard 3D gaussian splat storage convention + (log scale, logit opacity). + """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + normals = np.zeros_like(xyz) + f = sh.cpu().numpy().astype(np.float32) # (N, K, 3) + f_dc = f[:, 0, :] # (N, 3) + f_rest = f[:, 1:, :].transpose(0, 2, 1).reshape(n, -1) # (N, 3*(K-1)) channel-major + op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(1e-6, 1 - 1e-6) + op = np.log(op / (1.0 - op)) # inverse sigmoid (logit) + scale = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-8)) + rot = rotations.cpu().numpy().astype(np.float32) # (N, 4) + + attrs = (['x', 'y', 'z', 'nx', 'ny', 'nz'] + + [f'f_dc_{i}' for i in range(3)] + + [f'f_rest_{i}' for i in range(f_rest.shape[1])] + + ['opacity'] + [f'scale_{i}' for i in range(3)] + [f'rot_{i}' for i in range(4)]) + elements = np.empty(n, dtype=[(a, 'f4') for a in attrs]) + elements[:] = list(map(tuple, np.concatenate([xyz, normals, f_dc, f_rest, op, scale, rot], axis=1))) + + header = "ply\nformat binary_little_endian 1.0\n" + f"element vertex {n}\n" + header += "".join(f"property float {a}\n" for a in attrs) + "end_header\n" + return header.encode('ascii') + elements.tobytes() + + +# .ksplat (mkkellogg SplatBuffer) level 0, SH degree 0: 4096-byte header, one 1024-byte section header, +# then N 44-byte records. Bucketing/quantization only exist at levels >= 1. See SplatBuffer.js. +_KSPLAT_HEADER_BYTES = 4096 +_KSPLAT_SECTION_HEADER_BYTES = 1024 +_KSPLAT_BYTES_PER_SPLAT = 44 # center 12 + scale 12 + rotation 16 + color(RGBA u8) 4 +_KSPLAT_VERSION = (0, 1) # SplatBuffer CurrentMajor/MinorVersion + + +def _gaussian_ksplat_bytes(positions, scales, rotations, opacities, sh) -> bytes: + """Serialize gaussian tensors as a level-0, SH degree-0 .ksplat (linear scale, opacity in color alpha). + + positions (N,3) world; scales (N,3) linear; rotations (N,4) wxyz; opacities (N,1) in [0,1]; sh (N,K,3). + """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + scale = scales.cpu().numpy().astype(np.float32) + rot = rotations.cpu().numpy().astype(np.float32) # wxyz, mirrors the .ply rot order + rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12) + rgb = np.clip(sh[:, 0, :].cpu().numpy().astype(np.float32) * _C0 + 0.5, 0, 1) + op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(0, 1) + rgba = np.round(np.concatenate([rgb, op], axis=1) * 255.0).astype(np.uint8) # (N, 4) RGBA + + # 44-byte record: float center(3) + scale(3) + rot(4), then uint8 rgba(4). + floats = np.concatenate([xyz, scale, rot], axis=1).astype(' bytes: + """Serialize gaussian tensors as a gzip-compressed .spz (Niantic v2, SH degree 0, base color only). + + positions (N,3) world; scales (N,3) linear; rotations (N,4) wxyz; opacities (N,1) in [0,1]; sh (N,K,3). + """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + + # Positions: fixed point, masked to 24 bits, little-endian 3-byte words. + fixed = 1 << _SPZ_FRACTIONAL_BITS + qi = np.clip(np.round(xyz * fixed), -(1 << 23), (1 << 23) - 1).astype(np.int32) + qu = (qi & 0xFFFFFF).astype(np.uint32) + pos = np.stack([qu & 0xFF, (qu >> 8) & 0xFF, (qu >> 16) & 0xFF], axis=-1).reshape(n, 9).astype(np.uint8) + + alpha = np.round(opacities.cpu().numpy().astype(np.float32).reshape(n) * 255.0).clip(0, 255).astype(np.uint8) + + rgb = sh[:, 0, :].cpu().numpy().astype(np.float32) * _C0 + 0.5 + col = np.round(((rgb - 0.5) / _SPZ_COLOR_SCALE + 0.5) * 255.0).clip(0, 255).astype(np.uint8) # (N,3) + + sln = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-9)) + scb = np.round((sln + 10.0) * 16.0).clip(0, 255).astype(np.uint8) # (N,3) inverts exp(b/16-10) + + rot = rotations.cpu().numpy().astype(np.float32) # wxyz + rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12) + rot[rot[:, 0] < 0] *= -1.0 # canonical w >= 0 (w dropped on decode) + rotb = np.round((rot[:, 1:4] + 1.0) * 127.5).clip(0, 255).astype(np.uint8) # (N,3) x,y,z + + header = bytearray(16) + struct.pack_into(' (positions, scales linear, rotations wxyz, opacities [0,1], sh (N,K,3)) ---- +# Inverse of the writers above and of spark's loaders. ksplat/splat/spz carry base color only (SH degree 0 +# -> K=1); .ply round-trips full SH. None of the formats flip axes, so import is the identity of export. +_PLY_DTYPES = {'char': 'i1', 'uchar': 'u1', 'short': 'i2', 'ushort': 'u2', 'int': 'i4', 'uint': 'u4', + 'float': 'f4', 'double': 'f8', 'int8': 'i1', 'uint8': 'u1', 'int16': 'i2', 'uint16': 'u2', + 'int32': 'i4', 'uint32': 'u4', 'float32': 'f4', 'float64': 'f8'} +_KSPLAT_COMPRESSION = { # level -> (bytesPerCenter, scale, rotation, color, shComponent, defaultScaleRange) + 0: (12, 12, 16, 4, 4, 1), 1: (6, 6, 8, 4, 2, 32767), 2: (6, 6, 8, 4, 1, 32767)} +_KSPLAT_SH_COMPONENTS = {0: 0, 1: 9, 2: 24, 3: 45} + + +def _rgb_to_sh_dc(rgb): + return ((np.asarray(rgb, np.float32) - 0.5) / _C0)[:, None, :] # (N,3) base color -> (N,1,3) SH DC + + +def _norm_quat(q): + return q / np.linalg.norm(q, axis=1, keepdims=True).clip(1e-12) + + +def _parse_ply_gaussian(data: bytes): + end = data.find(b'end_header') + if end < 0: + raise ValueError("File3DToSplat: not a PLY (missing end_header)") + header = data[:end].decode('ascii', 'replace') + body = end + len(b'end_header') + body += 2 if data[body:body + 2] == b'\r\n' else 1 + count, props, in_vertex = 0, [], False + for line in header.splitlines(): + p = line.split() + if not p: + continue + if p[0] == 'format' and p[1] != 'binary_little_endian': + raise ValueError(f"File3DToSplat: unsupported PLY format '{p[1]}' (need binary_little_endian)") + if p[0] == 'element': + in_vertex = p[1] == 'vertex' + if in_vertex: + count = int(p[2]) + elif p[0] == 'property' and in_vertex: + if p[1] == 'list': + raise ValueError("File3DToSplat: PLY vertex has list properties (unsupported)") + props.append((p[2], '<' + _PLY_DTYPES[p[1]])) + arr = np.frombuffer(data, np.dtype(props), count=count, offset=body) + names = arr.dtype.names + c = lambda k: arr[k].astype(np.float32) + n = count + + xyz = np.stack([c('x'), c('y'), c('z')], 1) + if 'scale_0' in names: + scale = np.exp(np.stack([c('scale_0'), c('scale_1'), c('scale_2')], 1)) # 3DGS stores log scale + else: + scale = np.full((n, 3), 0.01, np.float32) + if 'rot_0' in names: + rot = _norm_quat(np.stack([c('rot_0'), c('rot_1'), c('rot_2'), c('rot_3')], 1)) # wxyz + else: + rot = np.tile(np.array([1, 0, 0, 0], np.float32), (n, 1)) + opacity = 1.0 / (1.0 + np.exp(-c('opacity'))) if 'opacity' in names else np.ones(n, np.float32) + + if 'f_dc_0' in names: + dc = np.stack([c('f_dc_0'), c('f_dc_1'), c('f_dc_2')], 1) # (N,3) + rest = sorted((k for k in names if k.startswith('f_rest_')), key=lambda s: int(s.split('_')[-1])) + if rest: + r = np.stack([c(k) for k in rest], 1) # (N, 3*(K-1)) channel-major + kk = r.shape[1] // 3 + 1 + r = r.reshape(n, 3, kk - 1).transpose(0, 2, 1) # -> (N, K-1, 3) + sh = np.concatenate([dc[:, None, :], r], 1) + else: + sh = dc[:, None, :] + elif 'red' in names: + sh = _rgb_to_sh_dc(np.stack([c('red'), c('green'), c('blue')], 1) / 255.0) + else: + sh = np.zeros((n, 1, 3), np.float32) + return xyz, scale, rot, opacity, sh + + +def _parse_splat_gaussian(data: bytes): + # antimatter15 .splat: 32-byte records (f32 xyz, f32 scale, u8 rgba, u8 quat as (b-128)/128 wxyz). + if len(data) % 32 != 0: + raise ValueError("File3DToSplat: .splat size is not a multiple of 32 bytes") + rec = np.frombuffer(data, np.dtype([('xyz', ' 0: + ct, ft = (' full_splats: + lengths = np.frombuffer(data, '> 30) & 3 + q = np.zeros((n, 4), np.float32) # x,y,z,w + remaining, sumsq = combined.copy(), np.zeros(n, np.float64) + for comp in (3, 2, 1, 0): + active = comp != largest + value = (remaining & 0x1FF).astype(np.float64) + sign = (remaining >> 9) & 1 + remaining = np.where(active, remaining >> 10, remaining) + val = (1.0 / math.sqrt(2)) * (value / 0x1FF) + val = np.where(sign == 1, -val, val) + q[active, comp] = val[active] + sumsq += np.where(active, val * val, 0.0) + q[np.arange(n), largest] = np.sqrt(np.clip(1.0 - sumsq, 0, None)) + rot = _norm_quat(np.stack([q[:, 3], q[:, 0], q[:, 1], q[:, 2]], 1)) # xyzw -> wxyz + else: + qb = np.frombuffer(raw, np.uint8, count=n * 3, offset=off).reshape(n, 3).astype(np.float32) + xq = qb / 127.5 - 1.0 + w = np.sqrt(np.clip(1.0 - (xq ** 2).sum(1), 0, None)) + rot = _norm_quat(np.concatenate([w[:, None], xq], 1)) # wxyz + return xyz, scale, rot, alpha, _rgb_to_sh_dc(rgb) + + +_GAUSSIAN_PARSERS = {"ply": _parse_ply_gaussian, "splat": _parse_splat_gaussian, + "ksplat": _parse_ksplat_gaussian, "spz": _parse_spz_gaussian} + + +def _detect_splat_format(data: bytes) -> str: + if data[:3] == b'ply': + return "ply" + if data[:2] == b'\x1f\x8b': # gzip -> spz + return "spz" + if len(data) >= 2 and data[0] == 0 and data[1] >= 1: # ksplat version 0.x header + return "ksplat" + if len(data) % 32 == 0: + return "splat" + raise ValueError("File3DToSplat: could not determine splat format from contents") + + +def _gaussian_item(g: Types.SPLAT, i: int, device): + # Slice batch item i to its real length, as float32 torch tensors on `device` (SH DC -> base RGB). + end = _real_len(g, i) + to = lambda a: a.to(device=device, dtype=torch.float32) + xyz = to(g.positions[i, :end]) + rgb = (to(g.sh[i, :end, 0, :]) * _C0 + 0.5).clamp(0, 1) + opacity = to(g.opacities[i, :end]).reshape(-1) + scale = to(g.scales[i, :end]) + rot = to(g.rotations[i, :end]) + return xyz, rgb, opacity, scale, rot + + +def _quat_to_mat(q): + # q: (N, 4) wxyz, normalized -> (N, 3, 3) + q = q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) + w, x, y, z = q.unbind(-1) + return torch.stack([ + 1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y), + 2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x), + 2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y), + ], dim=-1).reshape(-1, 3, 3) + + +def _quat_mul(a, b): + # Hamilton product a (x) b, wxyz. + aw, ax, ay, az = a.unbind(-1) + bw, bx, by, bz = b.unbind(-1) + return torch.stack([ + aw * bw - ax * bx - ay * by - az * bz, + aw * bx + ax * bw + ay * bz - az * by, + aw * by - ax * bz + ay * bw + az * bx, + aw * bz + ax * by - ay * bx + az * bw, + ], dim=-1) + + +def _euler_to_quat(rx, ry, rz): + # Degrees, applied as Rz @ Ry @ Rx (rotate about X, then Y, then Z in world). Returns wxyz. + c, s = np.cos(np.radians([rx, ry, rz]) / 2.0), np.sin(np.radians([rx, ry, rz]) / 2.0) + qx = torch.tensor([c[0], s[0], 0.0, 0.0], dtype=torch.float32) + qy = torch.tensor([c[1], 0.0, s[1], 0.0], dtype=torch.float32) + qz = torch.tensor([c[2], 0.0, 0.0, s[2]], dtype=torch.float32) + return _quat_mul(_quat_mul(qz, qy), qx) + + +def _mat_to_quat(m): + # Rotation matrix (..., 3, 3) -> quaternion (..., 4) wxyz. Batched; builds the four candidate quaternions + # and keeps the one with the largest component (numerically stable across all rotations). + m00, m11, m22 = m[..., 0, 0], m[..., 1, 1], m[..., 2, 2] + m21, m12 = m[..., 2, 1], m[..., 1, 2] + m02, m20 = m[..., 0, 2], m[..., 2, 0] + m10, m01 = m[..., 1, 0], m[..., 0, 1] + q2 = torch.stack([1 + m00 + m11 + m22, 1 + m00 - m11 - m22, + 1 - m00 + m11 - m22, 1 - m00 - m11 + m22], -1) # 4 * (w^2, x^2, y^2, z^2) + cand = torch.stack([ + torch.stack([q2[..., 0], m21 - m12, m02 - m20, m10 - m01], -1), + torch.stack([m21 - m12, q2[..., 1], m10 + m01, m02 + m20], -1), + torch.stack([m02 - m20, m10 + m01, q2[..., 2], m12 + m21], -1), + torch.stack([m10 - m01, m02 + m20, m12 + m21, q2[..., 3]], -1), + ], -2) # (...,4,4) candidates, rows = wxyz + sel = q2.argmax(-1) + q = torch.gather(cand, -2, sel[..., None, None].expand(sel.shape + (1, 4)))[..., 0, :] + return q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) + + +class SplatToFile3D(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SplatToFile3D", + display_name="Create 3D File (from Splat)", + search_aliases=["gaussian to ply", "splat to file", "export gaussian"], + category="3d/splat", + description="Serialize a gaussian splat to a File3D object for Save / Preview 3D nodes. " + "Supports one item per batch only.", + inputs=[ + IO.Splat.Input("splat"), + IO.Combo.Input("format", options=["ply", "ksplat", "spz"], # TODO: add "splat" when we have a writer for it + tooltip="ply: standard 3D Gaussian Splat with full spherical harmonics. " + "ksplat: mkkellogg SplatBuffer (level 0, uncompressed), base color only " + "spz: Niantic gzip-compressed (~10x smaller), base color only " + ), + ], + outputs=[IO.File3DAny.Output(display_name="model_3d")], + ) + + @classmethod + def execute(cls, splat, format="ply") -> IO.NodeOutput: + if splat.positions.shape[0] > 1: + logging.warning("SplatToFile3D supports one item per batch only. Got %d; using first.", splat.positions.shape[0]) + end = _real_len(splat, 0) + writer = {"ksplat": _gaussian_ksplat_bytes, "spz": _gaussian_spz_bytes}.get(format, _gaussian_ply_bytes) + data = writer(splat.positions[0, :end], splat.scales[0, :end], + splat.rotations[0, :end], splat.opacities[0, :end], splat.sh[0, :end]) + return IO.NodeOutput(Types.File3D(BytesIO(data), file_format=format)) + + +class File3DToSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="File3DToSplat", + display_name="Get Splat", + search_aliases=["load splat", "ply to splat", "import splat", "file to splat"], + category="3d/splat", + description="Parse a splat File3D into a gaussian splat. Inverse of Create 3D File (from Splat). " + "Supported format: PLY, SPLAT, KSPLAT, SPZ. PLY carries full spherical harmonics, " + "the other formats are base color only. Format is auto-detected from the file contents.", + inputs=[ + IO.MultiType.Input( + IO.File3DAny.Input("model_3d"), + types=[IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ], + tooltip="A gaussian splat 3D file", + ), + ], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, model_3d: Types.File3D) -> IO.NodeOutput: + data = model_3d.get_bytes() + fmt = (model_3d.format or "").lower() + parser = _GAUSSIAN_PARSERS.get(fmt) or _GAUSSIAN_PARSERS[_detect_splat_format(data)] + xyz, scale, rot, opacity, sh = parser(data) + + t = lambda a: torch.from_numpy(np.ascontiguousarray(a)).float() + splat = Types.SPLAT( + t(xyz)[None], # (1, N, 3) + t(scale)[None], # (1, N, 3) linear + t(rot)[None], # (1, N, 4) wxyz + t(opacity).reshape(1, -1, 1), # (1, N, 1) + t(sh)[None], # (1, N, K, 3) + ) + return IO.NodeOutput(splat) + + +def _view_matrix_t(yaw_deg, pitch_deg, device): + y, p = math.radians(yaw_deg), math.radians(pitch_deg) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + Ry = torch.tensor([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]], device=device) + Rx = torch.tensor([[1, 0, 0], [0, cp, -sp], [0, sp, cp]], device=device) + return Rx @ Ry + + +def _camera_basis(camera_info, dev): + # Look-at basis in the splat frame, named by their projection rows: right = image +x, up = image +y + # (down, since yflip=1), fwd = view/depth axis (eye -> scene). Load3D is three.js (right-handed, Y-up, + # camera looks down -Z); the splat is 3DGS (Y-down, Z-forward). World -> splat is a 180 deg rotation + # about X: (x, y, z) -> (x, -y, -z) (det +1, no mirror, no axis swap). + pos, tgt = camera_info.get("position", {}), camera_info.get("target", {}) + m = lambda d: torch.tensor([float(d.get("x", 0.0)), -float(d.get("y", 0.0)), -float(d.get("z", 0.0))], device=dev) + eye, target = m(pos), m(tgt) + mv = lambda v: torch.stack([v[0], -v[1], -v[2]]) # same world->splat map, for direction vectors + n = lambda v: v / v.norm().clamp_min(1e-8) + q = camera_info.get("quaternion") + if q: # exact camera world rotation (incl. roll) + qwxyz = torch.tensor([float(q.get("w", 1.0)), float(q.get("x", 0.0)), + float(q.get("y", 0.0)), float(q.get("z", 0.0))], device=dev) + R = _quat_to_mat(qwxyz[None])[0] # columns = camera world axes; looks down local -Z + right = n(mv(R[:, 0])) # camera +X -> image right + up = n(mv(-R[:, 1])) # camera +Y is image up; image-down row is its negative + fwd = n(mv(-R[:, 2])) # camera looks down local -Z -> view direction + return eye, target, right, up, fwd + fwd = n(target - eye) # no quaternion: orbit-consistent, roll-free + yaw = math.degrees(math.atan2(-float(fwd[0]), float(fwd[2]))) + pitch = math.degrees(math.asin(max(-1.0, min(1.0, float(fwd[1]))))) + W = _view_matrix_t(yaw, pitch, dev) + return eye, target, W[0], W[1], W[2] + + +def _lookat_quat_wxyz(position, target, dev): + # three.js lookAt in world frame: camera local +Z = (eye - target), up = world +Y. Returns wxyz. + z = position - target + z = z / z.norm().clamp_min(1e-8) + up0 = torch.tensor([0.0, 1.0, 0.0], device=dev) + if z.dot(up0).abs() > 0.999: # looking straight up/down + up0 = torch.tensor([0.0, 0.0, 1.0], device=dev) + x = torch.linalg.cross(up0, z) + x = x / x.norm().clamp_min(1e-8) + y = torch.linalg.cross(z, x) + R = torch.stack([x, y, z], dim=1) # columns = camera world axes + return _mat_to_quat(R[None])[0] + + +def _lookat_camera_info(position, target, fov, dev, zoom=1.0, camera_type="perspective", roll=0.0): + # Build a camera_info from a world-space (right-handed, Y-up) eye + look-at target; up = world +Y. + pos = torch.as_tensor(position, dtype=torch.float32, device=dev) + tgt = torch.as_tensor(target, dtype=torch.float32, device=dev) + q = _lookat_quat_wxyz(pos, tgt, dev) + if roll: # roll about the view axis (camera local Z) + a = math.radians(roll) + qz = torch.tensor([math.cos(a / 2), 0.0, 0.0, math.sin(a / 2)], device=dev) + q = _quat_mul(q[None], qz[None])[0] + xyz = lambda v: {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])} + return {"position": xyz(pos), "target": xyz(tgt), + "quaternion": {"x": float(q[1]), "y": float(q[2]), "z": float(q[3]), "w": float(q[0])}, + "fov": float(fov), "cameraType": str(camera_type), "zoom": float(zoom)} + + +def _quat_camera_info(position, quat_xyzw, fov, dev, zoom=1.0, camera_type="perspective"): + # camera_info from an explicit world position + camera-rotation quaternion (three.js: looks down local -Z). + pos = torch.as_tensor(position, dtype=torch.float32, device=dev) + qx, qy, qz, qw = (float(c) for c in quat_xyzw) + qwxyz = torch.tensor([qw, qx, qy, qz], dtype=torch.float32, device=dev) + qwxyz = qwxyz / qwxyz.norm().clamp_min(1e-8) + R = _quat_to_mat(qwxyz[None])[0] + tgt = pos - R[:, 2] # look one unit down local -Z + xyz = lambda v: {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])} + return {"position": xyz(pos), "target": xyz(tgt), + "quaternion": {"x": float(qwxyz[1]), "y": float(qwxyz[2]), "z": float(qwxyz[3]), "w": float(qwxyz[0])}, + "fov": float(fov), "cameraType": str(camera_type), "zoom": float(zoom)} + + +def _orbit_camera_info(yaw, pitch, distance, fov, pivot_splat, dev): + # Orbit helper for RenderSplat's default camera: yaw/pitch about `pivot_splat` (splat frame) at `distance`. + # World<->splat is the (x,-y,-z) map, so _camera_basis recovers exactly _view_matrix_t(yaw, pitch). + y, p = math.radians(yaw), math.radians(pitch) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + fwd_splat = torch.tensor([-cp * sy, sp, cp * cy], device=dev) # == _view_matrix_t(yaw, pitch)[2] + m = lambda v: torch.stack([v[0], -v[1], -v[2]]) # splat<->world (its own inverse) + return _lookat_camera_info(m(pivot_splat - distance * fwd_splat), m(pivot_splat), fov, dev) + + +def _orbit_camera_info_yaw(camera_info, angle_deg, dev): + # Turntable: rigidly rotate a camera_info about world +Y around its target by angle_deg. Returns a new dict. + a = math.radians(angle_deg) + ca, sa = math.cos(a), math.sin(a) + v = lambda d: torch.tensor([float(d.get("x", 0.0)), float(d.get("y", 0.0)), float(d.get("z", 0.0))], device=dev) + pos, tgt = v(camera_info.get("position", {})), v(camera_info.get("target", {})) + Ry = torch.tensor([[ca, 0.0, sa], [0.0, 1.0, 0.0], [-sa, 0.0, ca]], device=dev) + new_pos = tgt + Ry @ (pos - tgt) + q = camera_info.get("quaternion") or {} + qcur = torch.tensor([float(q.get("w", 1.0)), float(q.get("x", 0.0)), + float(q.get("y", 0.0)), float(q.get("z", 0.0))], device=dev) + qy = torch.tensor([math.cos(a / 2), 0.0, math.sin(a / 2), 0.0], device=dev) # world +Y rotation + qn = _quat_mul(qy[None], qcur[None])[0] + xyz = lambda t: {"x": float(t[0]), "y": float(t[1]), "z": float(t[2])} + return {**camera_info, "position": xyz(new_pos), + "quaternion": {"x": float(qn[1]), "y": float(qn[2]), "z": float(qn[3]), "w": float(qn[0])}} + + +def _gauss_blur(x, sigma, dev): + # Separable Gaussian blur of (1, C, H, W). Used to denoise the screen-space normal map. + r = max(1, int(round(3 * sigma))) + k = torch.exp(-0.5 * (torch.arange(-r, r + 1, device=dev, dtype=torch.float32) / sigma) ** 2) + k = k / k.sum() + c = x.shape[1] + x = torch.nn.functional.conv2d(x, k.view(1, 1, 1, -1).expand(c, 1, 1, -1), padding=(0, r), groups=c) + x = torch.nn.functional.conv2d(x, k.view(1, 1, -1, 1).expand(c, 1, -1, 1), padding=(r, 0), groups=c) + return x + + +def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, bg, camera_info, + sharpen=1.0, headlight_shading=0.0, render_style="color"): + # Perspective-correct anisotropic gaussian splat rasterizer. Each splat is weighted by its 3D Gaussian's + # peak along each pixel's ray (AAA / Hahlbohm), composited front-to-back across depth slabs. `render_style` + # selects the image: color / clay / depth / normal. Returns (image HxWx3, coverage mask HxW) on CPU. + dev = comfy.model_management.get_torch_device() + t = lambda a: torch.as_tensor(a, dtype=torch.float32, device=dev) + idev, idtype = comfy.model_management.intermediate_device(), comfy.model_management.intermediate_dtype() + xyz, rgb, opacity = t(xyz), t(rgb).clamp(0, 1), t(opacity).reshape(-1) + scale, rot = t(scale) * float(splat_scale), t(rot) + do_linear = render_style == "color" # colour blends in linear light, re-encoded at the end + if do_linear: + rgb = _srgb_to_linear(rgb) + flat = width * height + bg_t = t(bg) + bg_comp = _srgb_to_linear(bg_t) if do_linear else bg_t # background blended in the same space as the splats + need_depth = render_style == "depth" + need_normal = render_style in ("normal", "clay") or headlight_shading > 0 + + def background_only(): # no splats to rasterize -> just the background + empty mask + img = bg_t.expand(height, width, 3) if render_style == "color" else torch.zeros(height, width, 3, device=dev) + return img.to(idev, idtype), torch.zeros(height, width, device=idev, dtype=idtype) + + if xyz.shape[0] == 0: # empty input (e.g. all culled by opacity_threshold) + return background_only() + + eye, target, right, up, fwd = _camera_basis(camera_info, dev) # all camera state comes from camera_info + W = torch.stack([right, up, fwd], 0) # rows = camera axes (world -> camera) + cam = (xyz - eye) @ W.T + fov = float(camera_info.get("fov", 0) or 0) or 35.0 + zoom = float(camera_info.get("zoom", 1.0) or 1.0) # three.js digital zoom: scales the focal length + is_ortho = str(camera_info.get("cameraType", "")).lower().startswith("ortho") + xc, yc, zc = cam.unbind(-1) + + keep = zc > 1e-2 + xc, yc, zc, rgb, opacity, scale, rot = (a[keep] for a in (xc, yc, zc, rgb, opacity, scale, rot)) + if xc.shape[0] == 0: # nothing in front of the camera -> background only + return background_only() + if render_style == "clay": + rgb = torch.full_like(rgb, 0.75) # neutral albedo -> shading shows pure geometry + + f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom # fov over the smaller axis, x camera zoom + cx0, cy0 = width / 2, height / 2 + + # Camera-space 3D covariance per splat: Sigma = (W Rq) diag(scale^2) (W Rq)^T, plus a tiny relative + # regularizer for a stable inverse (a pixel-size Mip low-pass would over-thicken flat surfels and blur). + Mw = W[None] @ _quat_to_mat(rot) # (N,3,3) world -> camera + cam_cov = (Mw * scale.square()[:, None, :]) @ Mw.transpose(1, 2) + cam_cov = cam_cov + (cam_cov.diagonal(dim1=-2, dim2=-1).mean(-1) * 1e-3)[:, None, None] * torch.eye(3, device=dev) + + # Perspective-correct weighting: peak of the 3D Gaussian along each pixel ray. Precompute Si, Si@mu, mu^T Si mu. + mu = torch.stack([xc, yc, zc], -1) + si = torch.linalg.inv(cam_cov) + simu = (si @ mu[:, :, None])[:, :, 0] # (N,3) + musimu = (mu * simu).sum(-1) # (N,) + s00, s01, s02 = si[:, 0, 0], si[:, 0, 1], si[:, 0, 2] + s11, s12, s22 = si[:, 1, 1], si[:, 1, 2], si[:, 2, 2] + simu0, simu1, simu2 = simu.unbind(-1) + if need_normal: # surfel normal = thinnest axis, oriented toward camera + nrm = Mw[torch.arange(Mw.shape[0], device=dev), :, scale.argmin(-1)] # (N,3) camera-space normal + nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0) # flip so nz <= 0 (faces camera) + + # Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel). + # The image is +y-down, so the projection's y row is unflipped - it matches the splat frame's +Y. + jm = torch.zeros(xc.shape[0], 2, 3, device=dev) + if is_ortho: # parallel projection: screen = s * (xc, yc) + s = f / float((target - eye).norm().clamp_min(1e-6)) # pixels per world unit at the target plane + cx, cy = cx0 + s * xc, cy0 + s * yc + jm[:, 0, 0] = s + jm[:, 1, 1] = s + else: # perspective: screen = f * (xc, yc) / zc + invz = 1.0 / zc + cx, cy = cx0 + f * xc * invz, cy0 + f * yc * invz + jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square() + jm[:, 1, 1], jm[:, 1, 2] = f * invz, -f * yc * invz.square() + cov2 = jm @ cam_cov @ jm.transpose(1, 2) + a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1] + max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt() + radius = 3.0 * max_eig.clamp_min(1e-8).sqrt() + K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(_quantile(radius, 0.995).item())))) + + # Per-splat kernel size: bucket splats by radius into a coarse ladder of window sizes (global K stays the cap) so + # small splats (the bulk of it) use a small window. + levels = [L for L in (16, 64, 256) if L < K] + [K] + levels_t = torch.tensor(levels, device=dev, dtype=torch.float32) + grids = [] + for L in levels: + rng = torch.arange(-L, L + 1, device=dev, dtype=torch.float32) + gy, gx = torch.meshgrid(rng, rng, indexing="ij") + grids.append((gx.reshape(-1), gy.reshape(-1))) + blevel = torch.bucketize(radius * (4.0 / 3.0), levels_t).clamp_(max=len(levels) - 1) # window >= ~4 sigma + + n = zc.shape[0] + ns = int(min(256, max(1, n // 1000))) # depth slabs: 1 per ~1000 splats, capped + nl = len(levels) + order = torch.argsort(zc) # front (small zc) -> back -> defines the slabs + bounds = torch.linspace(0, n, ns + 1, device=dev).round().long() + rank = torch.empty(n, dtype=torch.long, device=dev) + rank[order] = torch.arange(n, device=dev) # depth rank of each splat + slab_id = (torch.searchsorted(bounds, rank, right=True) - 1).clamp_(0, ns - 1) + key = slab_id * nl + blevel # group by slab, then kernel level (order-free within) + order = torch.argsort(key) + key = key[order] + + cxr, cyr = cx[order].round(), cy[order].round() + s00, s01, s02 = s00[order], s01[order], s02[order] + s11, s12, s22 = s11[order], s12[order], s22[order] + s01b, s02b, s12b = s01 * 2, s02 * 2, s12 * 2 # doubled cross terms for the fused quadratic forms + simu0, simu1, simu2, musimu = simu0[order], simu1[order], simu2[order], musimu[order] + opacity, rgb = opacity[order], rgb[order] + zc_o = zc[order] if need_depth else None + nrm_o = nrm[order] if need_normal else None + mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None) + + # Pack the per-splat scalars into one tensor so each chunk slices once + common = [cxr, cyr, s00, s11, s22, s01b, s02b, s12b, opacity] + pstack = torch.stack(common + ([s02, s12, mux_o, muy_o, muz_o] if is_ortho else [simu0, simu1, simu2, musimu])) + + # Precompute the (slab, level) run table on-GPU and pull it to the CPU once + starts = torch.cat([torch.zeros(1, dtype=torch.long, device=dev), (key[1:] != key[:-1]).nonzero().flatten() + 1]) + ks = key[starts] + run_lo = starts.tolist() + [n] + run_lev = (ks % nl).tolist() + run_slab = torch.div(ks, nl, rounding_mode="floor").tolist() + slab_runs = [[] for _ in range(ns)] + for r in range(len(run_lev)): + slab_runs[run_slab[r]].append((run_lo[r], run_lo[r + 1], run_lev[r])) + + def splat(lo, hi, ox, oy): # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray + cols = pstack[:, lo:hi, None].unbind(0) + cxr_, cyr_, a00, a11, a22, b01, b02, b12, opa = cols[:9] # a* = Si components; b* = 2 * cross terms + px = cxr_ + ox[None, :] + py = cyr_ + oy[None, :] + valid = (px >= 0) & (px < width) & (py >= 0) & (py < height) + if is_ortho: # parallel ray (0,0,1) from screen point (X, Y, 0); rz constant per splat + c02, c12, mx, my, mz = cols[9:] + rx = (px - cx0) / s - mx + ry = (py - cy0) / s - my + rz = -mz + a22rz = a22 * rz + inx = torch.addcmul(b02 * rz, a00, rx).addcmul_(b01, ry) # a00 rx + b01 ry + b02 rz + rSr = torch.addcmul(a22rz * rz, rx, inx).addcmul_(ry, torch.addcmul(b12 * rz, a11, ry)) + dsr = torch.addcmul(a22rz, c02, rx).addcmul_(c12, ry) + q = torch.addcdiv(rSr, dsr * dsr, a22.clamp_min(1e-12), value=-1).clamp_min_(0) + else: # perspective ray (dx,dy,1) through the camera origin + su0, su1, su2, mus = cols[9:] + dx, dy = (px - cx0) / f, (py - cy0) / f + dsid = torch.addcmul(a22, dx, torch.addcmul(b02, a00, dx)) # a22 + dx*(a00 dx + b02) + dsid = dsid.addcmul_(dy, torch.addcmul(b12, a11, dy)) # + dy*(a11 dy + b12) + dsid = dsid.addcmul_(b01 * dx, dy) # + (2 s01) dx dy + dsimu = torch.addcmul(su2, dx, su0).addcmul_(dy, su1) + q = torch.addcdiv(mus, dsimu * dsimu, dsid.clamp_min(1e-12), value=-1).clamp_min_(0) + alpha = (opa * torch.exp(-0.5 * q) * valid).clamp_(0, 0.999) + idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1) + return idx, alpha + + # Front-to-back compositing over the depth slabs set up above. Within a slab the accumulation is a pure + # sum (order-independent), so splats are grouped by kernel level and each level uses its own tight window. + sharp = sharpen != 1.0 # winner-take-more colour blend: dominant splat shows more + cacc = torch.zeros((flat, 3), device=dev) + trans = torch.ones((flat,), device=dev) + a_buf = torch.zeros((flat,), device=dev) # sum alpha -> colour/depth/normal weight (alpha-weighted mean) + tau_buf = torch.zeros((flat,), device=dev) # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha) + crgb = torch.zeros((flat, 3), device=dev) # sum alpha^p * rgb -> slab colour + wbuf = torch.zeros((flat,), device=dev) if sharp else None # sum alpha^p -> colour normalizer (sharp only) + dacc = torch.zeros((flat,), device=dev) if need_depth else None # front-weighted depth + nacc = torch.zeros((flat, 3), device=dev) if need_normal else None # front-weighted camera-space normal + zslab = torch.zeros((flat,), device=dev) if need_depth else None + nslab = torch.zeros((flat, 3), device=dev) if need_normal else None + stale = 0 # consecutive fully-occluded slabs -> early-out + for si in range(ns): + runs = slab_runs[si] + if not runs: + continue + a_buf.zero_() + tau_buf.zero_() + crgb.zero_() + if sharp: + wbuf.zero_() + if need_depth: + zslab.zero_() + if need_normal: + nslab.zero_() + for r_lo, r_hi, li in runs: # contiguous same-kernel-level runs in this slab + ox, oy = grids[li] + ch = max(2048, 10_000_000 // ox.shape[0]) # splats/chunk, bounded by this level's kernel size + for lo in range(r_lo, r_hi, ch): + hi = min(lo + ch, r_hi) + idx, alpha = splat(lo, hi, ox, oy) + idx, af = idx.reshape(-1), alpha.reshape(-1) + a_buf.index_add_(0, idx, af) + tau_buf.index_add_(0, idx, (-torch.log1p(-alpha)).reshape(-1)) # -ln(1-alpha), correct opacity merge + apw = alpha.pow(sharpen) if sharp else alpha # bias colour toward the highest-alpha splat + crgb.index_add_(0, idx, (apw[:, :, None] * rgb[lo:hi, None, :]).reshape(-1, 3)) + if sharp: + wbuf.index_add_(0, idx, apw.reshape(-1)) + if need_depth: + zslab.index_add_(0, idx, (alpha * zc_o[lo:hi, None]).reshape(-1)) + if need_normal: + nslab.index_add_(0, idx, (alpha[:, :, None] * nrm_o[lo:hi, None, :]).reshape(-1, 3)) + slab_a = 1 - torch.exp(-tau_buf) # 1 - prod(1-alpha): true opacity of the slab's splats + front = trans * slab_a + denom = wbuf if sharp else a_buf + cacc.addcmul_(front[:, None], crgb / denom.clamp_min(1e-8)[:, None]) # cacc += front * (crgb/denom) + if need_depth or need_normal: + ainv = a_buf.clamp_min(1e-8) # alpha-weighted-mean normalizer (depth/normal only) + if need_depth: + dacc.addcmul_(front, zslab / ainv) + if need_normal: + nacc.addcmul_(front[:, None], nslab / ainv[:, None]) + trans.mul_(1 - slab_a) + if si % 8 == 7: # checkpoint every 8 slabs (a per-slab GPU sync would cost more) + if float(front.max()) < 1e-3: # this checkpoint slab is fully occluded by what is in front + stale += 1 + if stale >= 2: # two occluded checkpoints running -> the rest are too -> stop + break + else: + stale = 0 + + cov = 1 - trans + covg = cov.reshape(height, width) + covm = covg > 0.5 if render_style in ("depth", "normal") else None # silhouette mask (depth/normal styles only) + depth_map = (dacc / cov.clamp_min(1e-6)).reshape(height, width) if need_depth else None + nrm_map = None + if need_normal: + # Per-splat surfel normals are jittery, so do a masked blur + nb = nacc.reshape(height, width, 3).permute(2, 0, 1)[None] + cb = cov.reshape(1, 1, height, width) + nb, cb = _gauss_blur(nb, 1.2, dev), _gauss_blur(cb, 1.2, dev) + normal = (nb / cb.clamp_min(1e-6))[0].permute(1, 2, 0) + nrm_map = normal / normal.norm(dim=-1, keepdim=True).clamp_min(1e-6) + + if render_style == "depth": # near = bright, far = dark, 0 off-object + d = torch.zeros(height, width, device=dev) + if bool(covm.any()): + lo, hi = depth_map[covm].min(), depth_map[covm].max() + d = torch.where(covm, ((hi - depth_map) / (hi - lo).clamp_min(1e-6)).clamp(0, 1), d) + img = d[:, :, None].expand(height, width, 3) + elif render_style == "normal": # OpenGL normal map: +X right, +Y up, +Z to viewer + enc = (nrm_map * t([1.0, -1.0, -1.0]) * 0.5 + 0.5).clamp(0, 1) + img = enc * covm[:, :, None] + else: # color / clay + img = cacc.reshape(height, width, 3) + if render_style == "clay": # studio key light + ambient -> sculpted matte look + kl = t([-0.4, -0.7, -0.6]) # key from screen upper-left, angled toward the viewer + kl = kl / kl.norm() + hl = (0.5 * (nrm_map * kl).sum(-1) + 0.5).clamp(0, 1) # half-Lambert: soft terminator, no harsh dark side + img = img * (0.35 + 0.65 * hl * hl)[:, :, None] # ambient floor + diffuse key + elif headlight_shading > 0: # camera headlight: darken faces turned from view + k = float(headlight_shading) + ndotl = (-nrm_map[:, :, 2]).clamp(0, 1) + img = img * (1 - 0.6 * k + 0.6 * k * ndotl)[:, :, None] + img = img.addcmul_(trans.reshape(height, width, 1), bg_comp) + if do_linear: # back to display space after linear compositing + img = _linear_to_srgb(img) + return img.clamp(0, 1).to(idev, idtype), covg.clamp(0, 1).to(idev, idtype) + + +class RenderSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RenderSplat", + display_name="Render Splat", + search_aliases=["splat to image", "render splat", "gaussian turntable"], + category="3d/splat", + description="Render a gaussian splat as an image with an anisotropic EWA rasterizer (oriented " + "elliptical splats, antialiased, depth-sorted front-to-back). The camera comes from a " + "camera_info input (Load / Preview 3D, or a Create Camera Info node); leave it empty to " + "auto-frame the splat. Set frames greater than 1 for a turntable batch of images to feed a Video node.", + inputs=[ + IO.Splat.Input("splat"), + IO.Int.Input("width", default=1024, min=64, max=2048, step=8), + IO.Int.Input("height", default=1024, min=64, max=2048, step=8), + IO.Int.Input("frames", default=1, min=-240, max=240, + tooltip="-1, 0, 1 = single still image; >1 = turntable, the camera orbits over a full " + "360 turn (works with any camera_info). Negative value orbits the other way."), + IO.Float.Input("splat_scale", default=1.0, min=0.1, max=5.0, step=0.05, advanced=True, + tooltip="Multiplier on each splat's projected footprint (lower = crisper points, " + "higher = softer/fuller surface)."), + IO.Float.Input("sharpen", default=2.0, min=1.0, max=8.0, step=0.5, + tooltip="Sharpen overlapping splats: 1.0 = physically-correct blend; higher biases " + "each pixel toward its dominant (nearest) splat for crisper texture, without " + "shrinking splats or opening gaps. Non-physical above 1."), + IO.Float.Input("headlight_shading", default=0.0, min=0.0, max=3.0, step=0.05, advanced=True, + tooltip="Diffuse shading from a light at the camera (headlight), using the splat surfel " + "normals: darkens surfaces that turn away from view to reveal form/curvature. " + "0 = flat albedo, 1 = strongest shading."), + IO.Float.Input("opacity_threshold", default=0.0, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Cull gaussians with opacity below this (removes faint floaters)."), + IO.Combo.Input("render_style", options=["color", "clay", "depth", "normal"], + tooltip="What the image output shows: color, clay (neutral-albedo shaded), " + "depth (near=bright), normal (OpenGL normal map)."), + IO.Color.Input("background", default="#000000"), + IO.Image.Input("bg_image", optional=True, + tooltip="Optional background plate composited behind the splat (overrides the solid " + "background colour). Resized to the render size; a batch is used per frame, " + "a single image for all. color/clay only."), + IO.Load3DCamera.Input("camera_info", optional=True, + tooltip="Camera to render from - a Load3D / Preview3D camera or a Create Camera " + "Info node. If empty, the splat is auto-framed from a default 3/4 view."), + ], + outputs=[IO.Image.Output(display_name="image"), IO.Mask.Output(display_name="mask")], + ) + + @classmethod + def execute(cls, splat, width, height, frames, splat_scale, sharpen, headlight_shading, + opacity_threshold, background, render_style, camera_info=None, bg_image=None) -> IO.NodeOutput: + bg = _hex_to_rgb(background) + bg_imgs = None + if bg_image is not None: # resize the plate(s) to the render size: (B,H,W,3) + bi = comfy.utils.common_upscale(bg_image.movedim(-1, 1), width, height, "bicubic", "disabled") + bg_imgs = bi.movedim(1, -1).clamp(0, 1) + n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still) + orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction + imgs, masks = [], [] + device = comfy.model_management.get_torch_device() + total = splat.positions.shape[0] * n_frames + pbar = comfy.utils.ProgressBar(total) if total > 1 else None + k = 0 + for i in range(splat.positions.shape[0]): + xyz, rgb, opacity, scale, rot = _gaussian_item(splat, i, device) + if opacity_threshold > 0: + keep = opacity >= opacity_threshold + xyz, rgb, opacity, scale, rot = xyz[keep], rgb[keep], opacity[keep], scale[keep], rot[keep] + base_cam = camera_info + if base_cam is None: # no camera -> default 3/4 view, auto-framed on the splat + center = xyz.mean(0) if xyz.shape[0] else torch.zeros(3, device=device) + extent = (_quantile((xyz - center).norm(dim=-1), 0.99).clamp_min(1e-4) if xyz.shape[0] + else torch.tensor(1.0, device=device)) + dist = float(extent / (math.tan(math.radians(35.0) / 2) * 0.9)) + base_cam = _orbit_camera_info(35.0, 30.0, dist, 35.0, center, device) + for fr in range(n_frames): + cam_fr = (base_cam if n_frames == 1 + else _orbit_camera_info_yaw(base_cam, orbit_dir * 360.0 * fr / n_frames, device)) + bg_k = bg_imgs[k % bg_imgs.shape[0]] if bg_imgs is not None else bg # per-frame plate, or solid colour + img, mask = _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, bg_k, cam_fr, + sharpen=sharpen, headlight_shading=headlight_shading, + render_style=render_style) + imgs.append(img) + masks.append(mask) + k += 1 + if pbar is not None: + pbar.update(1) + return IO.NodeOutput(torch.stack(imgs), torch.stack(masks)) + + +class CreateCameraInfo(IO.ComfyNode): # TODO: move to better file + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="CreateCameraInfo", + display_name="Create Camera Info", + search_aliases=["camera position", "make camera info", "orbit camera", "look at camera"], + category="3d", + description="Build a camera_info" + "Mode 'orbit' aims with yaw/pitch/distance around the target; " + "'look_at' places the camera at world position. Coordinates are the viewer's world space (right-handed,Y-up).", + inputs=[ + IO.DynamicCombo.Input("mode", options=[ + IO.DynamicCombo.Option("orbit", [ + IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0), + IO.Float.Input("distance", default=4.0, min=0.01, max=1000.0, step=0.01, + tooltip="Camera distance from the target."), + ]), + IO.DynamicCombo.Option("look_at", [ + IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01, + tooltip="Camera position in world space (right-handed, Y-up)."), + IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01), + IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01), + ]), + IO.DynamicCombo.Option("quaternion", [ + IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01, + tooltip="Camera position in world space (right-handed, Y-up)."), + IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01), + IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01), + IO.Float.Input("quat_x", default=0.0, min=-1.0, max=1.0, step=0.001), + IO.Float.Input("quat_y", default=0.0, min=-1.0, max=1.0, step=0.001), + IO.Float.Input("quat_z", default=0.0, min=-1.0, max=1.0, step=0.001), + IO.Float.Input("quat_w", default=1.0, min=-1.0, max=1.0, step=0.001, + tooltip="Camera world-rotation quaternion (three.js: looks down local -Z). Normalized for you."), + ]), + ], tooltip="How to define the camera: orbit angles, an explicit position, or a position + quaternion."), + IO.Float.Input("target_x", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True, + tooltip="Look-at point (orbit pivot / aim). In orbit mode, move it to pan/translate the " + "whole camera. Ignored in quaternion mode. Defaults to the origin."), + IO.Float.Input("target_y", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True), + IO.Float.Input("target_z", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True), + IO.Float.Input("roll", default=0.0, min=-180.0, max=180.0, step=1.0, + tooltip="Camera roll about the view axis, degrees."), + IO.Float.Input("fov", default=35.0, min=1.0, max=120.0, step=1.0, + tooltip="Vertical field of view in degrees."), + IO.Float.Input("zoom", default=1.0, min=0.01, max=100.0, step=0.01, + tooltip="Digital zoom (focal-length multiplier). >1 zooms in without moving the camera."), + IO.Combo.Input("camera_type", options=["perspective", "orthographic"], + tooltip="Projection used by Render Splat: perspective (foreshortening) or orthographic (parallel)."), + ], + outputs=[IO.Load3DCamera.Output(display_name="camera_info")], + ) + + @classmethod + def execute(cls, mode, target_x, target_y, target_z, roll, fov, zoom=1.0, camera_type="perspective") -> IO.NodeOutput: + dev = comfy.model_management.get_torch_device() + kind = mode["mode"] + if kind == "quaternion": # explicit world position + camera rotation + position = [mode["position_x"], mode["position_y"], mode["position_z"]] + quat = [mode["quat_x"], mode["quat_y"], mode["quat_z"], mode["quat_w"]] + return IO.NodeOutput(_quat_camera_info(position, quat, fov, dev, zoom=zoom, camera_type=camera_type)) + target = [target_x, target_y, target_z] # orbit pivot / aim; move it to pan the whole camera + if kind == "orbit": # yaw/pitch/distance about the target (world Y-up) + y, p = math.radians(mode["yaw"]), math.radians(mode["pitch"]) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + d = mode["distance"] + position = [target_x + d * cp * sy, target_y + d * sp, target_z + d * cp * cy] + else: # look_at: explicit world-space camera position + position = [mode["position_x"], mode["position_y"], mode["position_z"]] + return IO.NodeOutput(_lookat_camera_info(position, target, fov, dev, zoom=zoom, camera_type=camera_type, roll=roll)) + + +class TransformSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TransformSplat", + display_name="Transform Splat", + search_aliases=["move splat", "rotate splat", "scale splat", "gaussian transform"], + category="3d/splat", + description="Translate, rotate, and scale a gaussian splat. " + "Non-uniform scale also reshapes every individual splat, slower process.", + inputs=[ + IO.Splat.Input("splat"), + IO.Float.Input("translate_x", default=0.0, min=-100.0, max=100.0, step=0.01), + IO.Float.Input("translate_y", default=0.0, min=-100.0, max=100.0, step=0.01), + IO.Float.Input("translate_z", default=0.0, min=-100.0, max=100.0, step=0.01), + IO.Float.Input("rotate_x", default=0.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("rotate_y", default=0.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("rotate_z", default=0.0, min=-360.0, max=360.0, step=1.0), + IO.Float.Input("scale_x", default=1.0, min=0.01, max=100.0, step=0.01), + IO.Float.Input("scale_y", default=1.0, min=0.01, max=100.0, step=0.01), + IO.Float.Input("scale_z", default=1.0, min=0.01, max=100.0, step=0.01), + ], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, splat, translate_x, translate_y, translate_z, + rotate_x, rotate_y, rotate_z, scale_x, scale_y, scale_z) -> IO.NodeOutput: + pos = splat.positions + dev, dt = pos.device, pos.dtype + q_rot = _euler_to_quat(rotate_x, rotate_y, rotate_z).to(device=dev, dtype=dt) + R = _quat_to_mat(q_rot[None])[0] # (3, 3) node rotation + D = torch.tensor([scale_x, scale_y, scale_z], dtype=dt, device=dev) + A = D[:, None] * R # diag(D) @ R: per-axis scale after rotation + t = torch.tensor([translate_x, translate_y, translate_z], dtype=dt, device=dev) + + positions = pos @ A.T + t # rotate, scale per-axis, then translate + if scale_x == scale_y == scale_z: # uniform: rotation/scale factor out cleanly + scales = splat.scales * scale_x + rotations = _quat_mul(q_rot.expand_as(splat.rotations), splat.rotations) + rotations = rotations / rotations.norm(dim=-1, keepdim=True).clamp_min(1e-12) + else: # non-uniform: transform Sigma = A R s^2 R^T A^T, re-extract + rg = _quat_to_mat(splat.rotations.reshape(-1, 4)) # (M,3,3) per-splat rotation + s2 = splat.scales.reshape(-1, 3).square() + cov = (rg * s2[:, None, :]) @ rg.transpose(-1, -2) # Sigma + cov = A @ cov @ A.T # A Sigma A^T (A broadcast over splats) + lam, V = torch.linalg.eigh(cov) # symmetric -> eigenvalues (asc), orthonormal axes + V = V * torch.where(torch.linalg.det(V) < 0, -1.0, 1.0)[..., None, None] # keep a proper rotation + scales = lam.clamp_min(0).sqrt().reshape(splat.scales.shape) + rotations = _mat_to_quat(V).reshape(splat.rotations.shape) + out = Types.SPLAT(positions, scales, rotations, splat.opacities, splat.sh, + counts=getattr(splat, "counts", None)) + return IO.NodeOutput(out) + + +class GetSplatCount(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GetSplatCount", + display_name="Get Splat Count", + search_aliases=["splat count", "gaussian count", "number of splats", "splat info"], + category="3d/splat", + description="Returns the number of splats summed across the batch.", + inputs=[IO.Splat.Input("splat")], + outputs=[IO.Splat.Output(display_name="splat"), + IO.Int.Output(display_name="count"), + ], + hidden=[IO.Hidden.unique_id], + ) + + @classmethod + def execute(cls, splat) -> IO.NodeOutput: + count = sum(_real_len(splat, i) for i in range(splat.positions.shape[0])) + if cls.hidden.unique_id: # show the count inline on the node + PromptServer.instance.send_progress_text(f"{count:,} splats", cls.hidden.unique_id) + return IO.NodeOutput(splat, count) + + +def _pad_stack(items, n): + # Stack a list of (Lᵢ, *tail) tensors into (B, n, *tail), zero-padding each row up to n. + tail = items[0].shape[1:] + out = items[0].new_zeros((len(items), n, *tail)) + for i, t in enumerate(items): + out[i, :t.shape[0]] = t + return out + + +def _merge_gaussians(gaussians: list) -> Types.SPLAT: + # Concatenate SPLAT batches along the splat dimension (per item), padding SH to the highest degree. + gs = [g for g in gaussians if g is not None] + if not gs: + raise ValueError("MergeSplat: no gaussians to merge") + b = gs[0].positions.shape[0] + for g in gs: + if g.positions.shape[0] != b: + raise ValueError(f"MergeSplat: batch size mismatch ({b} vs {g.positions.shape[0]}).") + max_k = max(g.sh.shape[2] for g in gs) + + pos_b, scl_b, rot_b, op_b, sh_b, lengths = [], [], [], [], [], [] + for i in range(b): + pos_i, scl_i, rot_i, op_i, sh_i = [], [], [], [], [] + for g in gs: + end = _real_len(g, i) + pos_i.append(g.positions[i, :end]) + scl_i.append(g.scales[i, :end]) + rot_i.append(g.rotations[i, :end]) + op_i.append(g.opacities[i, :end]) + sh = g.sh[i, :end] # (end, K, 3) + if sh.shape[1] < max_k: # zero-pad lower-degree SH + sh = torch.cat([sh, sh.new_zeros(sh.shape[0], max_k - sh.shape[1], sh.shape[2])], dim=1) + sh_i.append(sh) + pos_b.append(torch.cat(pos_i)) + scl_b.append(torch.cat(scl_i)) + rot_b.append(torch.cat(rot_i)) + op_b.append(torch.cat(op_i)) + sh_b.append(torch.cat(sh_i)) + lengths.append(pos_b[-1].shape[0]) + + n = max(lengths) + counts = None + if len(set(lengths)) > 1: + counts = torch.tensor(lengths, device=gs[0].positions.device, dtype=torch.int64) + return Types.SPLAT(_pad_stack(pos_b, n), _pad_stack(scl_b, n), _pad_stack(rot_b, n), + _pad_stack(op_b, n), _pad_stack(sh_b, n), counts=counts) + + +class MergeSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + # Autogrow: a splat0/splat1/... input list that grows a fresh slot as you connect splats. + splats = IO.Autogrow.TemplatePrefix(IO.Splat.Input("splat"), prefix="splat", min=2, max=32) + return IO.Schema( + node_id="MergeSplat", + display_name="Merge Splats", + search_aliases=["union splat", "densify gaussian", "combine splat", "merge gaussian"], + category="3d/splat", + description="Concatenate any number of gaussian splats into one. Unioning several decodes of the same " + "latent at different seeds densifies the surface, this can improve surface quality when meshing.", + inputs=[IO.Autogrow.Input("splats", template=splats)], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, splats: IO.Autogrow.Type) -> IO.NodeOutput: + gs = [v for v in splats.values() if v is not None] + if not gs: + raise ValueError("MergeSplat: connect at least one splat.") + return IO.NodeOutput(_merge_gaussians(gs)) + + +def _inverse_covariance(scale, quat): + # Per-splat Sigma^-1 = R diag(1/s^2) R^T. scale (N,3) linear std, quat (N,4) wxyz -> (N,3,3). + q = quat / quat.norm(dim=1, keepdim=True).clamp_min(1e-12) + w, x, y, z = q.unbind(-1) + R = torch.stack([ + 1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y), + 2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x), + 2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y), + ], dim=1).reshape(-1, 3, 3) + inv_s2 = 1.0 / scale.clamp_min(1e-8) ** 2 # (N, 3) + return torch.einsum("nij,nj,nkj->nik", R, inv_s2, R) + + +def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sharpen=1.0, chunk=4096, progress=None, + col_dtype=torch.float16): + # Splat each gaussian as its oriented-covariance disk (3-sigma, opacity-weighted) into a density grid, + # plus a colour volume. Each gaussian uses a voxel window sized to its OWN 3-sigma (capped at `kernel`). + # Colour is weighted by w^color_sharpen: >1 biases each voxel toward its dominant gaussian (crisper + # texture). Returns (density, colour numerator, colour normaliser, origin, voxel). + pad = 4.0 * scale.median() + lo = xyz.amin(0) - pad + hi = xyz.amax(0) + pad + voxel = ((hi - lo).max() / res).clamp_min(1e-8) + dx, dy, dz = (torch.ceil((hi - lo) / voxel).long() + 1).tolist() + + sinv = _inverse_covariance(scale, quat) + kreq = torch.ceil(3.0 * scale.amax(-1) / voxel).long().clamp(1, int(kernel)) # per-gaussian half-width + sharp = color_sharpen != 1.0 + vol = torch.zeros(dx * dy * dz, device=device) # Sum(w) density (surface) + colvol = torch.zeros(dx * dy * dz, 3, device=device, dtype=col_dtype) # Sum(w^p * rgb) colour numerator + wcol = torch.zeros(dx * dy * dz, device=device, dtype=col_dtype) if sharp else None # Sum(w^p) normaliser (p>1) + n, done = xyz.shape[0], 0 + for k in range(1, int(kernel) + 1): + sel = (kreq == k).nonzero(as_tuple=True)[0] + if sel.numel() == 0: + continue + rng = torch.arange(-k, k + 1, device=device, dtype=torch.float32) + off = torch.stack(torch.meshgrid(rng, rng, rng, indexing="ij"), -1).reshape(-1, 3) # (M, 3) + for st in range(0, sel.numel(), chunk): + gi = sel[st:st + chunk] + cc = xyz[gi] + idx = ((cc - lo) / voxel).round()[:, None, :] + off[None] # (b, M, 3) voxel coords + d = (lo + idx * voxel) - cc[:, None, :] # world offset to voxel center + quad = torch.einsum("bmi,bij,bmj->bm", d, sinv[gi], d) + wgt = opacity[gi, None] * torch.exp(-0.5 * quad) + wgt = torch.where(quad < 9.0, wgt, torch.zeros_like(wgt)) # clip beyond 3 sigma + ii = idx.long() + ix = ii[..., 0].clamp(0, dx - 1) + iy = ii[..., 1].clamp(0, dy - 1) + iz = ii[..., 2].clamp(0, dz - 1) + flat = (ix * (dy * dz) + iy * dz + iz).reshape(-1) + vol.index_add_(0, flat, wgt.reshape(-1)) + wp = wgt.pow(color_sharpen) if sharp else wgt # winner-take-more colour weight + colvol.index_add_(0, flat, (wp[..., None] * rgb[gi, None, :]).reshape(-1, 3).to(col_dtype)) + if sharp: + wcol.index_add_(0, flat, wp.reshape(-1).to(col_dtype)) + done += gi.numel() + if progress is not None: + progress(min(1.0, done / max(1, n))) + colnorm = (wcol if sharp else vol).reshape(dx, dy, dz) # p==1 -> Sum(w) == density + return vol.reshape(dx, dy, dz), colvol.reshape(dx, dy, dz, 3), colnorm, lo.cpu().numpy(), float(voxel) + + +def _connected_components_gpu(faces, nv): + # FastSV connected components: grandparent hooking + shortcutting, ~O(log nv) iterations. + # Returns per-vertex component labels (min node id, not densified). + a = torch.cat([faces[:, 0], faces[:, 1]]) # 2F edge endpoints: (v0,v1),(v1,v2) + b = torch.cat([faces[:, 1], faces[:, 2]]) + f = torch.arange(nv, device=faces.device) + while True: + gp = f[f] # grandparent + ga, gb = gp[a], gp[b] + new = f.clone() + new.scatter_reduce_(0, f[a], gb, "amin", include_self=True) # stochastic hooking onto roots + new.scatter_reduce_(0, f[b], ga, "amin", include_self=True) + new.scatter_reduce_(0, a, gb, "amin", include_self=True) # aggressive hooking, both directions + new.scatter_reduce_(0, b, ga, "amin", include_self=True) + new = new[new] # shortcut (path compression) + if torch.equal(new, f): + return f + f = new + + +def _clean_components_gpu(verts, faces, min_verts, device): + # GPU port of _clean_components: FastSV components + scatter reductions. Byte-identical to the numpy path + vt = torch.as_tensor(verts, device=device) + ft = torch.as_tensor(faces, device=device) + nv = vt.shape[0] + _, label = torch.unique(_connected_components_gpu(ft, nv), return_inverse=True) # dense 0..ncomp-1 + ncomp = int(label.max()) + 1 + flabel = label[ft[:, 0]] # component id per face + keep = torch.bincount(label, minlength=ncomp) >= min_verts # per-component vertex-count gate + if int(keep.sum()) > 1: + fcount = torch.bincount(flabel, minlength=ncomp) + largest = int(torch.where(keep, fcount, fcount.new_tensor(-1)).argmax()) + v0, v1, v2 = vt[ft[:, 0]], vt[ft[:, 1]], vt[ft[:, 2]] + cvol = torch.zeros(ncomp, device=device).scatter_add_(0, flabel, (v0 * torch.linalg.cross(v1, v2)).sum(-1)) + idx3 = label[:, None].expand(-1, 3) # per-component vertex bbox + cmin = torch.full((ncomp, 3), float("inf"), device=device).scatter_reduce_(0, idx3, vt, "amin", include_self=True) + cmax = torch.full((ncomp, 3), float("-inf"), device=device).scatter_reduce_(0, idx3, vt, "amax", include_self=True) + tol = 1e-4 * (cmax[largest] - cmin[largest]).max() + enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1) + inner = enclosed & (torch.sign(cvol) != torch.sign(cvol[largest])) & (torch.arange(ncomp, device=device) != largest) + keep &= ~inner + faces_k = ft[keep[flabel]] + if faces_k.shape[0] == 0: + return verts[:0], faces[:0] + used = torch.unique(faces_k) # sorted, matches np.unique + remap = torch.full((nv,), -1, dtype=torch.int64, device=device) + remap[used] = torch.arange(used.shape[0], device=device) + return vt[used].cpu().numpy(), remap[faces_k].cpu().numpy() + + +def _clean_components(verts, faces, min_verts, device=None): + # Drop floaters (components with < min_verts vertices) and inner shells - the surfel shell density + # extracts a double wall (outer + inner cavity surface). GPU path (FastSV CC + scatter reductions, ~13x + # faster) when an accelerator has headroom; else numpy/scipy. Both produce byte-identical output. + if device is not None and not comfy.model_management.is_device_cpu(device) and \ + comfy.model_management.get_free_memory(device) > 10 * faces.size * 8: # peak ~8.4x faces bytes + return _clean_components_gpu(verts, faces, min_verts, device) + nv = len(verts) + e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0) + ncomp, label = connected_components(coo_matrix((np.ones(len(e)), (e[:, 0], e[:, 1])), shape=(nv, nv)), directed=False) + flabel = label[faces[:, 0]] # component id per face + keep = np.bincount(label, minlength=ncomp) >= min_verts # per-component vertex-count gate + if keep.sum() > 1: + fcount = np.bincount(flabel, minlength=ncomp) + largest = np.where(keep, fcount, -1).argmax() + v0, v1, v2 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]] + cvol = np.bincount(flabel, weights=np.einsum("ij,ij->i", v0, np.cross(v1, v2)), minlength=ncomp) # 6*signed vol + cidx = np.arange(ncomp) # per-component vertex bbox via ndimage (~6x faster than ufunc.at) + cmin = np.stack([_ndi_minimum(verts[:, a], label, cidx) for a in range(3)], 1) + cmax = np.stack([_ndi_maximum(verts[:, a], label, cidx) for a in range(3)], 1) + tol = 1e-4 * (cmax[largest] - cmin[largest]).max() + enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1) + inner = enclosed & (np.sign(cvol) != np.sign(cvol[largest])) & (np.arange(ncomp) != largest) + keep &= ~inner + faces = faces[keep[flabel]] + if len(faces) == 0: + return verts[:0], faces + used = np.unique(faces) + remap = np.full(nv, -1, np.int64) + remap[used] = np.arange(len(used)) + return verts[used], remap[faces] + + +def _surface_nets(vol, level, voxel, origin, device): + # Vectorized Surface Nets: one dual vertex per sign-changing cell at its edge-crossing mean, quads wound CCW-outward. + # Returns verts (V,3), faces (F,3). + vol = vol.to(device=device, dtype=torch.float32) + dx, dy, dz = vol.shape + origin_t = torch.as_tensor(origin, device=device, dtype=torch.float32) + empty = (np.zeros((0, 3), np.float32), np.zeros((0, 3), np.int64)) + if dx < 2 or dy < 2 or dz < 2: + return empty + + # Active = cells whose 8 corners aren't all in/all out. + inside = vol >= level # (dx,dy,dz) bool + cs8 = [inside[ox:ox + dx - 1, oy:oy + dy - 1, oz:oz + dz - 1] + for ox, oy, oz in ((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0), + (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1))] + any_in = cs8[0] | cs8[1] | cs8[2] | cs8[3] | cs8[4] | cs8[5] | cs8[6] | cs8[7] + all_in = cs8[0] & cs8[1] & cs8[2] & cs8[3] & cs8[4] & cs8[5] & cs8[6] & cs8[7] + active = any_in & ~all_in # (cx,cy,cz) straddling cells + nv = int(active.sum()) + if nv == 0: + return empty + + # Active cells only (a thin shell): each dual vertex = mean of its 12 edges' zero-crossings. + del any_in, all_in, cs8 # corner bool grids no longer needed + ac = active.nonzero(as_tuple=False) # (nv,3) cell min-corner indices + offs = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]], device=device) + offf = offs.to(torch.float32) + edges = torch.tensor([[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], + [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]], device=device) + e0, e1 = edges[:, 0], edges[:, 1] + oe0, oe1 = offf[e0], offf[e1] # (12,3) edge endpoints + + cstep = 1 << 18 # chunk to bound peak memory (CPU RAM too) + loc = [] + for st in range(0, nv, cstep): + ci = ac[st:st + cstep, None, :] + offs[None] # (m,8,3) + cval = vol[ci[..., 0], ci[..., 1], ci[..., 2]] # (m,8) corner values + csl = cval >= level + v0, v1 = cval[:, e0], cval[:, e1] # (m,12) + cross = (csl[:, e0] != csl[:, e1])[..., None].to(torch.float32) + denom = v1 - v0 + t = torch.where(denom.abs() > 1e-12, (level - v0) / denom, torch.full_like(denom, 0.5)).clamp(0, 1) + pts = torch.lerp(oe0, oe1, t[..., None]) # (m,12,3) local crossings (fused interp) + loc.append((pts * cross).sum(1) / cross.sum(1).clamp_min(1.0)) # (m,3) in [0,1] + local = torch.cat(loc, 0) if len(loc) > 1 else loc[0] # (nv,3) + verts = origin_t + (ac.to(torch.float32) + local) * voxel # world space + del loc, local, ac + + vid = torch.full((dx - 1, dy - 1, dz - 1), -1, dtype=torch.int32, device=device) + vid[active] = torch.arange(nv, dtype=torch.int32, device=device) + del active + + # Each straddling grid edge -> one quad from its 4 cells; `sol` (low-end sign) picks outward winding. + faces = [] + + def emit(cr, sol, a, b, d, c): + valid = cr & (a >= 0) & (b >= 0) & (c >= 0) & (d >= 0) + if not bool(valid.any()): + return + a, b, c, d, sol = a[valid], b[valid], c[valid], d[valid], sol[valid] + p2, p4 = torch.where(sol, b, c), torch.where(sol, c, b) # reverse quad winding where ~sol + faces.append(torch.stack([a, p2, d], 1)) + faces.append(torch.stack([a, d, p4], 1)) + + a = inside[0:dx - 1, 1:dy - 1, 1:dz - 1] + emit(a != inside[1:dx, 1:dy - 1, 1:dz - 1], a, + vid[:, 0:dy - 2, 0:dz - 2], vid[:, 1:dy - 1, 0:dz - 2], + vid[:, 1:dy - 1, 1:dz - 1], vid[:, 0:dy - 2, 1:dz - 1]) + a = inside[1:dx - 1, 0:dy - 1, 1:dz - 1] + emit(a != inside[1:dx - 1, 1:dy, 1:dz - 1], a, + vid[0:dx - 2, :, 0:dz - 2], vid[0:dx - 2, :, 1:dz - 1], + vid[1:dx - 1, :, 1:dz - 1], vid[1:dx - 1, :, 0:dz - 2]) + a = inside[1:dx - 1, 1:dy - 1, 0:dz - 1] + emit(a != inside[1:dx - 1, 1:dy - 1, 1:dz], a, + vid[0:dx - 2, 0:dy - 2, :], vid[1:dx - 1, 0:dy - 2, :], + vid[1:dx - 1, 1:dy - 1, :], vid[0:dx - 2, 1:dy - 1, :]) + + if not faces: + return empty + return verts.cpu().numpy().astype(np.float32), torch.cat(faces, 0).cpu().numpy().astype(np.int64) + + +def _otsu_level(values, bins=256): + # Otsu threshold: the density value that best splits inside/outside (max between-class variance). + hist, edges = np.histogram(values, bins=bins) + hist = hist.astype(np.float64) + centers = (edges[:-1] + edges[1:]) * 0.5 + w = np.cumsum(hist) # background-class weight at each split + mu = np.cumsum(hist * centers) + wf = w[-1] - w # foreground-class weight + mb = mu / np.where(w > 0, w, 1.0) + mf = (mu[-1] - mu) / np.where(wf > 0, wf, 1.0) + var_b = w * wf * (mb - mf) ** 2 # between-class variance + var_b[(w <= 0) | (wf <= 0)] = -1.0 + return float(centers[int(np.argmax(var_b))]) + + +def _taubin_smooth(verts, faces, iters, lam=0.5, mu=-0.53): + # Taubin lambda|mu smoothing: low-pass the mesh surface without the shrinkage of a Laplacian blur + # (the mu inflation pass cancels the lambda pass's volume loss). Uniform (umbrella) weights. + if iters <= 0 or len(verts) == 0 or len(faces) == 0: + return verts + nv = len(verts) + e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0) + e = np.concatenate([e, e[:, ::-1]], 0) # symmetric adjacency + adj = coo_matrix((np.ones(len(e), np.float32), (e[:, 0], e[:, 1])), shape=(nv, nv)).tocsr() + adj.data[:] = 1.0 + deg = np.clip(np.asarray(adj.sum(1)).ravel(), 1.0, None).astype(np.float32)[:, None] + v = verts.astype(np.float32) # fp32 matvec: ~2x faster, sub-micron drift on unit-scale verts + for _ in range(int(iters)): + for fac in (lam, mu): + v = v + np.float32(fac) * ((adj @ v) / deg - v) # fac * (mean(neighbours) - v) + return np.ascontiguousarray(v) + + +def _sample_vertex_colours_gpu(colvol, colnorm, verts, origin, voxel, device): + # GPU trilinear sampling of the colour numerator (3ch) and normaliser (1ch) at vertex grid-coords + # reproduces scipy map_coordinates(order=1, mode='nearest'). Returns col (V,3) numpy. + dx, dy, dz = colnorm.shape + vt = torch.as_tensor(verts, device=device, dtype=torch.float32) + org = torch.as_tensor(origin, device=device, dtype=torch.float32) + gi = (vt - org) / voxel # (V,3) grid-index coords (x,y,z) + size = torch.tensor([dx, dy, dz], device=device, dtype=torch.float32) + g = 2.0 * gi / (size - 1).clamp_min(1.0) - 1.0 # -> [-1,1] (align_corners) + grid = torch.stack([g[:, 2], g[:, 1], g[:, 0]], -1)[None, None, None] # (1,1,1,V,3): grid_sample order (W=z,H=y,D=x) + + def samp(v): # (dx,dy,dz,C) cpu fp16 -> (C,V) fp32 on device + inp = v.to(device).permute(3, 0, 1, 2)[None].float() + o = torch.nn.functional.grid_sample(inp, grid, mode="bilinear", padding_mode="border", align_corners=True) + return o[0, :, 0, 0, :] + num = samp(colvol) # (3,V) + den = samp(colnorm[..., None]) # (1,V) + return (num / den.clamp_min(1e-8)).T.cpu().numpy() # (V,3) + + +def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None): + # Mesh one splat: density + colour grids -> Surface Nets -> floater removal -> Taubin smoothing -> + # volume-sampled colours. Returns (verts, faces int64, colors in [0,1]), or None if no surface. + rep = progress if progress is not None else (lambda *_: None) + + end = _real_len(g, i) + xyz = g.positions[i, :end].to(device=device, dtype=torch.float32) + scale = g.scales[i, :end].to(device=device, dtype=torch.float32) + quat = g.rotations[i, :end].to(device=device, dtype=torch.float32) + opacity = g.opacities[i, :end].reshape(-1).to(device=device, dtype=torch.float32) + rgb = (g.sh[i, :end, 0, :].to(device=device, dtype=torch.float32) * _C0 + 0.5).clamp(0, 1) + + keep = opacity >= min_opacity + xyz, scale, quat, opacity, rgb = xyz[keep], scale[keep], quat[keep], opacity[keep], rgb[keep] + if xyz.shape[0] == 0: + return None + + vol, colvol, colnorm, origin, voxel = _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, + color_sharpen=color_sharpen, + progress=lambda f: rep(0.25 * f)) # density build: 0 -> 25% + # Colour: sample on the GPU (grid_sample) when there's headroom + colour_gpu = not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) > 6 * vol.numel() * 4 + if colour_gpu: + colvol_cpu, colnorm_cpu = colvol.cpu(), colnorm.half().cpu() # park colours (fp16) off-GPU during meshing + colvol_np = colnorm_np = None + else: + colvol_np = colvol.cpu().numpy().astype(np.float32) # Sum(w^p * rgb) colour numerator (fp16 grid -> fp32) + colnorm_np = colnorm.cpu().numpy().astype(np.float32) # Sum(w^p) colour normaliser + del colvol, colnorm # free the colour grids before iso-surfacing + rep(0.40) + + vmin, vmax = float(vol.min()), float(vol.max()) + occ = vol[vol > vmax * 1e-3] # occupied voxels (skip the empty-space peak) + if occ.numel() == 0: + return None + # Otsu picks the inside/outside split principledly; `level_bias` nudges it (1.0 = auto). Clamp strictly + # inside the data range so a bias can't push the iso off the histogram. + level = min(max(_otsu_level(occ.cpu().numpy()) * level_bias, vmin + 1e-6 * (vmax - vmin)), + vmax - 1e-6 * (vmax - vmin)) + + # Iso-surface on the accelerator when there's headroom: ~15x faster than CPU, identical output. Chunked + # Surface Nets peaks at ~3-3.5x the density grid, so fall back to CPU for large grids / tight VRAM. + sn_dev = device + if not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) < 6 * vol.numel() * 4: + sn_dev = torch.device("cpu") + vol = vol.cpu() + verts, faces = _surface_nets(vol, level, voxel, origin, sn_dev) + del vol + rep(0.55) + if min_component > 0 and len(faces) > 0: + verts, faces = _clean_components(verts, faces, min_component, device) + if len(verts) == 0 or len(faces) == 0: + return None + + # Taubin smooths the blocky iso without shrinking it (unlike blurring the density, which rounds features). + verts = _taubin_smooth(verts, faces, taubin) + rep(0.7) + + # Colour each vertex from the co-splatted colour volume: trilinearly sample the numerator Sum(w^p*rgb) + # and normaliser Sum(w^p) separately, then divide. Normalising AFTER interpolation keeps zero-density + # edge voxels from pulling colours toward black, and matches the gaussians that formed the surface. + if colour_gpu: + col = _sample_vertex_colours_gpu(colvol_cpu, colnorm_cpu, verts, origin, voxel, device) + else: + coords = ((verts - origin) / voxel).T # (3, V) grid-index coords, matching volume axes + num = np.stack([map_coordinates(colvol_np[..., c], coords, order=1, mode="nearest") for c in range(3)], -1) + den = map_coordinates(colnorm_np, coords, order=1, mode="nearest") + col = num / np.clip(den, 1e-8, None)[:, None] + rep(1.0) + + # The unlit material's COLOR_0 is linear and the viewer sRGB-encodes it on output; the splat colours + # are display (sRGB) values, so convert sRGB -> linear here to land at the same brightness as the splat. + col = np.clip(col, 0, 1) + col = np.where(col <= 0.04045, col / 12.92, ((col + 0.055) / 1.055) ** 2.4).astype(np.float32) + + # Splat +Y is glTF's -Y: rotate 180 deg about X (negate Y,Z) to land upright. Proper rotation, so + # winding is kept; done after colouring (which works in the splat frame). + verts = np.ascontiguousarray(verts * np.array([1.0, -1.0, -1.0], dtype=np.float32)) + return (torch.from_numpy(verts), torch.from_numpy(faces), torch.from_numpy(col)) + + +class SplatToMesh(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SplatToMesh", + display_name="Extract Mesh from Splat", + search_aliases=["splat to mesh", "gaussian surface nets", "splat surface", "mesh splat"], + category="3d/splat", + description="Extract a coloured mesh from a gaussian splat.", + inputs=[ + IO.Splat.Input("splat"), + IO.Int.Input("resolution", default=384, min=64, max=768, step=16, + tooltip="Density-grid resolution along the longest axis. Higher = finer surface, " + "more VRAM/time (grows with resolution^3)."), + IO.Int.Input("kernel", default=5, min=1, max=8, + tooltip="Max splat half-width in voxels. Each gaussian is rasterized over a window " + "sized to its own 3-sigma, capped here - small surfels stay cheap, large ones " + "aren't truncated. Raise if sparse splats leave gaps."), + IO.Int.Input("smooth", default=0, min=0, max=60, advanced = True, + tooltip="Taubin mesh-smoothing iterations. Smooths the surface without shrinking it " + "(volume-preserving), unlike blurring the density. 0 = raw surface."), + IO.Float.Input("level", default=0.4, min=0.0, max=2.0, step=0.01, + tooltip="Iso-surface level. Auto-picked by Otsu; this biases it (1.0 = auto, lower = " + "fatter/more-connected surface, higher = thinner/tighter)."), + IO.Int.Input("min_component", default=500, min=0, max=100000, step=50, advanced=True, + tooltip="Drop connected components smaller than this many vertices (0 = keep all). " + "Removes detached floater blobs and the inner shell of the double wall."), + IO.Float.Input("min_opacity", default=0.02, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Ignore gaussians fainter than this before meshing."), + IO.Float.Input("color_sharpen", default=2.0, min=1.0, max=8.0, step=0.5, + tooltip="Crisp up the vertex texture: 1.0 = physically-correct blend; higher biases " + "each voxel's colour toward its dominant gaussian instead of averaging " + "neighbours (de-smears the texture). Colour only - geometry is unchanged."), + ], + outputs=[IO.Mesh.Output(display_name="mesh")], + ) + + @classmethod + def execute(cls, splat, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen) -> IO.NodeOutput: + device = comfy.model_management.get_torch_device() + b = splat.positions.shape[0] + prec = 1000 # each splat owns a 0..prec block of the bar; its callback advances within that block + pbar = comfy.utils.ProgressBar(b * prec) + + verts_l, faces_l, colors_l = [], [], [] + for i in range(b): + cb = lambda f, base=i * prec: pbar.update_absolute(base + int(min(max(f, 0.0), 1.0) * prec)) + res = _gaussian_to_mesh(splat, i, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen, device, cb) + if res is None: + logging.warning("SplatToMesh: splat %d produced no surface; emitting an empty mesh.", i) + v, f, c = torch.zeros((0, 3)), torch.zeros((0, 3), dtype=torch.int64), torch.zeros((0, 3)) + else: + v, f, c = res + verts_l.append(v) + faces_l.append(f) + colors_l.append(c) + pbar.update_absolute((i + 1) * prec) # snap to block end (covers empty / early-out splats) + # unlit: render flat (emissive-like) so SaveGLB matches the splat instead of lighting/washing it. + return IO.NodeOutput(pack_variable_mesh_batch(verts_l, faces_l, colors=colors_l, unlit=True)) + + +class GaussianExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [SplatToFile3D, File3DToSplat, RenderSplat, CreateCameraInfo, TransformSplat, + GetSplatCount, MergeSplat, SplatToMesh] + + +async def comfy_entrypoint() -> GaussianExtension: + return GaussianExtension() diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index c03524246..a91549e7f 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -16,7 +16,7 @@ from comfy.cli_args import args from comfy_api.latest import ComfyExtension, IO, Types -def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None): +def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False): # Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors, # stashing per-item lengths as runtime attrs so consumers can recover the real slice. # colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts. @@ -54,7 +54,7 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non return Types.MESH(packed_vertices, packed_faces, uvs=packed_uvs, vertex_colors=packed_colors, texture=texture, - vertex_counts=vertex_counts, face_counts=face_counts) + vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit) def get_mesh_batch_item(mesh, index): @@ -77,7 +77,7 @@ def get_mesh_batch_item(mesh, index): def save_glb(vertices, faces, filepath, metadata=None, - uvs=None, vertex_colors=None, texture_image=None): + uvs=None, vertex_colors=None, texture_image=None, unlit=False): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. @@ -234,6 +234,17 @@ def save_glb(vertices, faces, filepath, metadata=None, textures = [] samplers = [] materials = [] + extensions_used = [] + if unlit and texture_png_bytes is None: + # Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a + # gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours. + materials.append({ + "pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0], "metallicFactor": 0.0, "roughnessFactor": 1.0}, + "extensions": {"KHR_materials_unlit": {}}, + "doubleSided": True, + }) + extensions_used.append("KHR_materials_unlit") + primitive["material"] = 0 if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes: buffer_views.append({ "buffer": 0, @@ -271,6 +282,8 @@ def save_glb(vertices, faces, filepath, metadata=None, gltf["textures"] = textures if materials: gltf["materials"] = materials + if extensions_used: + gltf["extensionsUsed"] = extensions_used if metadata: gltf["asset"]["extras"] = metadata @@ -376,7 +389,8 @@ class SaveGLB(IO.ComfyNode): save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata, uvs=uvs_i, vertex_colors=v_colors, - texture_image=tex_img) + texture_image=tex_img, + unlit=getattr(mesh, "unlit", False)) results.append({ "filename": f, "subfolder": subfolder, diff --git a/nodes.py b/nodes.py index 528bf316f..5678bc22d 100644 --- a/nodes.py +++ b/nodes.py @@ -2455,6 +2455,7 @@ async def init_builtin_extra_nodes(): "nodes_save_3d.py", "nodes_moge.py", "nodes_mediapipe.py", + "nodes_gaussian_splat.py", ] import_failed = [] From 4f7882a7becf96c54a5376a38b8a9649c627c8da Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 1 Jun 2026 06:40:49 +0300 Subject: [PATCH 16/32] [Partner Nodes] feat: added grok-imagine-video-1.5 model to the GrokVideo node in First Frame mode (#14198) --- comfy_api_nodes/nodes_grok.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index a41da42f3..ca8f534ed 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -29,6 +29,11 @@ from comfy_api_nodes.util import ( ) +_GROK_VIDEO_MODEL_API_IDS = { + "grok-imagine-video-1.5": "grok-imagine-video-1.5-preview", +} + + def _extract_grok_price(response) -> float | None: if response.usage and response.usage.cost_in_usd_ticks is not None: return response.usage.cost_in_usd_ticks / 10_000_000_000 @@ -504,7 +509,11 @@ class GrokVideoNode(IO.ComfyNode): category="video/partner/Grok", description="Generate video from a prompt or an image", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video"]), + IO.Combo.Input( + "model", + options=["grok-imagine-video", "grok-imagine-video-1.5"], + tooltip="grok-imagine-video-1.5 currently always requires an input image.", + ), IO.String.Input( "prompt", multiline=True, @@ -540,7 +549,11 @@ class GrokVideoNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), - IO.Image.Input("image", optional=True), + IO.Image.Input( + "image", + optional=True, + tooltip="Optional starting image for grok-imagine-video. Required for grok-imagine-video-1.5.", + ), ], outputs=[ IO.Video.Output(), @@ -552,12 +565,16 @@ class GrokVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"], inputs=["image"]), expr=""" ( - $rate := widgets.resolution = "720p" ? 0.07 : 0.05; + $is15 := $contains(widgets.model, "1.5"); + $rate := $is15 + ? (widgets.resolution = "720p" ? 0.2002 : 0.1144) + : (widgets.resolution = "720p" ? 0.07 : 0.05); + $imgCost := $is15 ? 0.0143 : 0.002; $base := $rate * widgets.duration; - {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} + {"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base} ) """, ), @@ -574,6 +591,8 @@ class GrokVideoNode(IO.ComfyNode): seed: int, image: Input.Image | None = None, ) -> IO.NodeOutput: + if image is None and model == "grok-imagine-video-1.5": + raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.") image_url = None if image is not None: if get_number_of_images(image) != 1: @@ -584,7 +603,7 @@ class GrokVideoNode(IO.ComfyNode): cls, ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), data=VideoGenerationRequest( - model=model, + model=_GROK_VIDEO_MODEL_API_IDS.get(model, model), image=image_url, prompt=prompt, resolution=resolution, @@ -599,7 +618,7 @@ class GrokVideoNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, - price_extractor=_extract_grok_price, + price_extractor=_extract_grok_video_price if model == "grok-imagine-video-1.5" else _extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) From 70a2e1a8513aada68847ce63eaa2f54ece53a893 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 31 May 2026 20:47:00 -0700 Subject: [PATCH 17/32] Remove old portable updater migration code. (#14202) * Remove old portable updater migration code. This is 2 years old so I don't think it's needed anymore. * Delete new_updater.py --- main.py | 7 ------- new_updater.py | 35 ----------------------------------- 2 files changed, 42 deletions(-) delete mode 100644 new_updater.py diff --git a/main.py b/main.py index bce451a83..239a52013 100644 --- a/main.py +++ b/main.py @@ -464,13 +464,6 @@ def start_comfyui(asyncio_loop=None): folder_paths.set_temp_directory(temp_dir) cleanup_temp() - if args.windows_standalone_build: - try: - import new_updater - new_updater.update_windows_updater() - except: - pass - if not asyncio_loop: asyncio_loop = asyncio.new_event_loop() asyncio.set_event_loop(asyncio_loop) diff --git a/new_updater.py b/new_updater.py deleted file mode 100644 index 9a203acdd..000000000 --- a/new_updater.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -import shutil - -base_path = os.path.dirname(os.path.realpath(__file__)) - - -def update_windows_updater(): - top_path = os.path.dirname(base_path) - updater_path = os.path.join(base_path, ".ci/update_windows/update.py") - bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat") - - dest_updater_path = os.path.join(top_path, "update/update.py") - dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat") - dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat") - - try: - with open(dest_bat_path, 'rb') as f: - contents = f.read() - except: - return - - if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"): - return - - shutil.copy(updater_path, dest_updater_path) - try: - with open(dest_bat_deps_path, 'rb') as f: - contents = f.read() - contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause') - with open(dest_bat_deps_path, 'wb') as f: - f.write(contents) - except: - pass - shutil.copy(bat_path, dest_bat_path) - print("Updated the windows standalone package updater.") # noqa: T201 From 462c27fdb2b84e612cdd4b3c7fac8875b04eda43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:01:50 +0300 Subject: [PATCH 18/32] feat: Add TripoSplat support (#14210) --- comfy/clip_vision.py | 16 +- comfy/image_encoders/dino3.py | 260 ++++++++++++++++++ comfy/latent_formats.py | 10 + comfy/ldm/triposplat/gaussian.py | 199 ++++++++++++++ comfy/ldm/triposplat/model.py | 326 +++++++++++++++++++++++ comfy/ldm/triposplat/preview.py | 91 +++++++ comfy/ldm/triposplat/vae.py | 382 +++++++++++++++++++++++++++ comfy/model_base.py | 19 ++ comfy/model_detection.py | 3 + comfy/sd.py | 11 + comfy/supported_models.py | 25 ++ comfy_extras/nodes_gaussian_splat.py | 3 +- comfy_extras/nodes_triposplat.py | 269 +++++++++++++++++++ nodes.py | 1 + 14 files changed, 1612 insertions(+), 3 deletions(-) create mode 100644 comfy/image_encoders/dino3.py create mode 100644 comfy/ldm/triposplat/gaussian.py create mode 100644 comfy/ldm/triposplat/model.py create mode 100644 comfy/ldm/triposplat/preview.py create mode 100644 comfy/ldm/triposplat/vae.py create mode 100644 comfy_extras/nodes_triposplat.py diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1691fca81..337575191 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -2,6 +2,7 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl import os import json import logging +import torch import comfy.ops import comfy.model_patcher @@ -9,6 +10,7 @@ import comfy.model_management import comfy.utils import comfy.clip_model import comfy.image_encoders.dino2 +import comfy.image_encoders.dino3 class Output: def __getitem__(self, key): @@ -23,12 +25,16 @@ IMAGE_ENCODERS = { "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, + "dinov3": comfy.image_encoders.dino3.DINOv3ViTModel, } class ClipVisionModel(): def __init__(self, json_config): - with open(json_config) as f: - config = json.load(f) + if isinstance(json_config, dict): + config = json_config + else: + with open(json_config) as f: + config = json.load(f) self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) @@ -44,6 +50,10 @@ class ClipVisionModel(): self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) + if self.model_type == "dinov3" and self.dtype == torch.float16: + # DINOv3's activations borderline fits fp16, preferring bf16 if available for better stability #TODO: further fp16 tests in practice + if comfy.model_management.should_use_bf16(self.load_device, prioritize_performance=True): + self.dtype = torch.bfloat16 self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() @@ -134,6 +144,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") elif 'encoder.layer.23.layer_scale2.lambda1' in sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") + elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers) + json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG else: return None diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py new file mode 100644 index 000000000..9bd42a66b --- /dev/null +++ b/comfy/image_encoders/dino3.py @@ -0,0 +1,260 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale + + +# DINOv3 ViT-H/16+ (SwiGLU) +DINOV3_VITH_CONFIG = { + "model_type": "dinov3", + "num_hidden_layers": 32, + "hidden_size": 1280, + "num_attention_heads": 20, + "num_register_tokens": 4, + "intermediate_size": 5120, + "layer_norm_eps": 1e-5, + "num_channels": 3, + "patch_size": 16, + "rope_theta": 100.0, + "use_gated_mlp": True, + "gated_mlp_act": "silu", + "image_size": 1024, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225], +} + + +class DINOv3ViTMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.GELU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, **kwargs): + num_tokens = q.shape[-2] + num_patches = sin.shape[-2] + num_prefix_tokens = num_tokens - num_patches + + q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2) + k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2) + + q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin) + k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin) + + q = torch.cat((q_prefix_tokens, q_patches), dim=-2) + k = torch.cat((k_prefix_tokens, k_patches), dim=-2) + + return q, k + + +class DINOv3ViTAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): + super().__init__() + self.embed_dim = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False + self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs): + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attn = optimized_attention_for_device(query_states.device, mask=False) + attn_output = attn( + query_states, key_states, value_states, self.num_heads, attention_mask, + skip_reshape=True, skip_output_reshape=True, low_precision_attention=False, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class DINOv3ViTGatedMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device): + coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) + coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) + coords_h = coords_h / num_patches_h + coords_w = coords_w / num_patches_w + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) + coords = coords.flatten(0, 1) + coords = 2.0 * coords - 1.0 + return coords + + +class DINOv3ViTRopePositionEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype): + super().__init__() + self.base = rope_theta + self.head_dim = hidden_size // num_attention_heads + self.patch_size = patch_size + + inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, pixel_values): + _, _, height, width = pixel_values.shape + num_patches_h = height // self.patch_size + num_patches_w = width // self.patch_size + + patch_coords = get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device) + self.inv_freq = self.inv_freq.to(pixel_values.device) + angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] + angles = angles.flatten(1, 2) + angles = angles.tile(2) + cos = torch.cos(angles).to(dtype=pixel_values.dtype) + sin = torch.sin(angles).to(dtype=pixel_values.dtype) + return cos, sin + + +class DINOv3ViTEmbeddings(nn.Module): + def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations): + super().__init__() + self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) + self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) + self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype)) + self.patch_embeddings = operations.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype + ) + + def forward(self, pixel_values, bool_masked_pos=None): + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embeddings.weight.dtype + + patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) + + if bool_masked_pos is not None: + mask_token = self.mask_token.to(patch_embeddings.dtype) + patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) + + cls_token = self.cls_token.expand(batch_size, -1, -1).to(patch_embeddings.device) + register_tokens = self.register_tokens.expand(batch_size, -1, -1).to(patch_embeddings.device) + embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) + return embeddings + + +class DINOv3ViTLayer(nn.Module): + def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, + num_attention_heads, device, dtype, operations, gated_mlp_act="silu"): + super().__init__() + self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) + self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) + + self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + if use_gated_mlp: + self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act) + else: + self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations) + self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings) + hidden_states = self.layer_scale1(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.layer_scale2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class DINOv3ViTModel(nn.Module): + def __init__(self, config, dtype, device, operations): + super().__init__() + num_hidden_layers = config["num_hidden_layers"] + hidden_size = config["hidden_size"] + num_attention_heads = config["num_attention_heads"] + num_register_tokens = config["num_register_tokens"] + intermediate_size = config["intermediate_size"] + layer_norm_eps = config["layer_norm_eps"] + num_channels = config["num_channels"] + patch_size = config["patch_size"] + rope_theta = config["rope_theta"] + use_gated_mlp = config.get("use_gated_mlp", False) + gated_mlp_act = config.get("gated_mlp_act", "silu") + + self.embeddings = DINOv3ViTEmbeddings( + hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, + dtype=dtype, device=device, operations=operations + ) + self.rope_embeddings = DINOv3ViTRopePositionEmbedding( + rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device + ) + self.layer = nn.ModuleList([ + DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True, + intermediate_size=intermediate_size, num_attention_heads=num_attention_heads, + dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act) + for _ in range(num_hidden_layers)]) + self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward(self, pixel_values, bool_masked_pos=None, **kwargs): + pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + position_embeddings = self.rope_embeddings(pixel_values) + + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings) + + if kwargs.get("skip_norm_elementwise", False): + sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:]) + else: + norm = self.norm.to(hidden_states.device) + sequence_output = norm(hidden_states) + pooled_output = sequence_output[:, 0, :] + return sequence_output, None, pooled_output, None diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 12a934d71..bbdfd4bc2 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -239,6 +239,16 @@ class Flux2(LatentFormat): def process_out(self, latent): return latent +class TripoSplat(LatentFormat): + # Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent + latent_channels = 16 + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent + class Mochi(LatentFormat): latent_channels = 12 latent_dimensions = 3 diff --git a/comfy/ldm/triposplat/gaussian.py b/comfy/ldm/triposplat/gaussian.py new file mode 100644 index 000000000..a4cd2f62f --- /dev/null +++ b/comfy/ldm/triposplat/gaussian.py @@ -0,0 +1,199 @@ +# TripoSplat 3D gaussian container. Operates on already-decoded +# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type. +import torch +import torch.nn.functional as F + +import comfy.model_management + + +class GaussianModel: + def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0, + scaling_bias: float = 0.01, opacity_bias: float = 0.1, + scaling_activation: str = "exp", device=None): + self.sh_degree = sh_degree + self.mininum_kernel_size = mininum_kernel_size + self.scaling_bias = scaling_bias + self.opacity_bias = opacity_bias + self.device = device + self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + + if scaling_activation == "exp": + self._scaling_activation = torch.exp + self._inverse_scaling_activation = torch.log + elif scaling_activation == "softplus": + self._scaling_activation = F.softplus + self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x)) + + self._opacity_activation = torch.sigmoid + self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x)) + + self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device) + self.rots_bias = torch.zeros(4, device=self.device) + self.rots_bias[0] = 1 + self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device) + + self._storage = {} + + def _get_store(self, name): + return self._storage.get(name) + + def _set_store(self, name, value): + self._storage[name] = value + + @property + def _xyz(self): + return self._get_store("_xyz") + @_xyz.setter + def _xyz(self, value): + if value is None: + self._set_store("_xyz", None) + self._set_store("xyz", None) + return + self._set_store("_xyz", value) + self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3]) + + @property + def get_xyz(self): + return self._get_store("xyz") + + @property + def _features_dc(self): + return self._get_store("_features_dc") + @_features_dc.setter + def _features_dc(self, value): + self._set_store("_features_dc", value) + + @property + def _opacity(self): + return self._get_store("_opacity") + @_opacity.setter + def _opacity(self, value): + if value is None: + self._set_store("_opacity", None) + self._set_store("opacity", None) + return + self._set_store("_opacity", value) + self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val)) + + @property + def get_opacity(self): + return self._get_store("opacity") + + @property + def _scaling(self): + return self._get_store("_scaling") + @_scaling.setter + def _scaling(self, value): + if value is None: + self._set_store("_scaling", None) + self._set_store("scaling", None) + return + self._set_store("_scaling", value) + s = self._scaling_activation(value + self.scale_bias) + s = torch.square(s) + self.mininum_kernel_size ** 2 + self._set_store("scaling", torch.sqrt(s)) + + @property + def get_scaling(self): + return self._get_store("scaling") + + @property + def _rotation(self): + return self._get_store("_rotation") + @_rotation.setter + def _rotation(self, value): + self._set_store("_rotation", value) + + _DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]] + + def render_tensors(self): + # Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform + # (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations. + # Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear, + # rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients. + xyz = self.get_xyz.float() + scaling = self.get_scaling.float() + opacity = self.get_opacity.float() + rotation = (self._rotation + self.rots_bias[None, :]).float() + sh = self._features_dc.float() # (N, K, 3) + T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device) + xyz = xyz @ T.T + rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation))) + rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True) + out_device = comfy.model_management.intermediate_device() + return ( + xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(), + rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(), + sh.to(out_device).contiguous(), + ) + + +def _quat_to_matrix(q): + q = q / torch.linalg.norm(q, dim=-1, keepdim=True) + w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3] + R = torch.stack([ + 1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y), + 2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x), + 2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y), + ], dim=-1).reshape(-1, 3, 3) + return R + + +def _matrix_to_quat(R): + trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2] + q = torch.zeros((R.shape[0], 4), dtype=R.dtype, device=R.device) + s = torch.sqrt(torch.clamp(trace + 1, min=0)) * 2 + q[:, 0] = 0.25 * s + denom = torch.where(s != 0, s, torch.ones_like(s)) + q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / denom + q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / denom + q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / denom + m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0) + s1 = torch.sqrt(torch.clamp(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], min=0)) * 2 + q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01] + q[m01, 1] = 0.25 * s1[m01] + q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01] + q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01] + m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0) + s2 = torch.sqrt(torch.clamp(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], min=0)) * 2 + q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11] + q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11] + q[m11, 2] = 0.25 * s2[m11] + q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11] + m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0) + s3 = torch.sqrt(torch.clamp(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], min=0)) * 2 + q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21] + q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21] + q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21] + q[m21, 3] = 0.25 * s3[m21] + return q / torch.linalg.norm(q, dim=-1, keepdim=True) + + +def build_gaussian_models(decoder, points_pred: dict, pred: dict): + # Assemble GaussianModels from the elastic decoder layout. decoder is the ElasticGaussianFixedlenDecoder + # (carries layout / rep_config / _get_offset) + x = points_pred + offset = decoder._get_offset(pred['features']) + h = pred["features"] + ret = [] + for i in range(h.shape[0]): + g = GaussianModel( + sh_degree=0, + aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0], + mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'], + scaling_bias=decoder.rep_config['scaling_bias'], + opacity_bias=decoder.rep_config['opacity_bias'], + scaling_activation=decoder.rep_config['scaling_activation'], + device=h.device, + ) + _x = x["points"][i, :, None, :] + for k, v in decoder.layout.items(): + if k == '_xyz': + setattr(g, k, (offset[i] + _x).flatten(0, 1)) + elif k in ('_xyz_center', '_offset_scale'): + continue + else: + feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1) + setattr(g, k, feats * decoder.rep_config['lr'][k]) + ret.append(g) + return ret diff --git a/comfy/ldm/triposplat/model.py b/comfy/ldm/triposplat/model.py new file mode 100644 index 000000000..d8a531772 --- /dev/null +++ b/comfy/ldm/triposplat/model.py @@ -0,0 +1,326 @@ +# TripoSplat flow-matching denoiser (LatentSeqMMFlowModel). Registered as a ModelType.FLOW arch and +# driven by the standard KSampler; jointly denoises the (B, 8192, 16) latent and a (B, 1, 5) camera token +# carried as a 2-element nested latent. +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.model_management +import comfy.patcher_extension +import comfy.rmsnorm +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope + + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim, heads, dtype=None, device=None): + super().__init__() + self.gamma = nn.Parameter(torch.empty(heads, dim, dtype=dtype, device=device)) + + def forward(self, x): + x = comfy.rmsnorm.rms_norm(x) + return x * comfy.model_management.cast_to(self.gamma, x.dtype, x.device) + + +# Positional embeddings + +class RePo3DRotaryEmbedding(nn.Module): + def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0, + dtype=None, device=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + repo_hidden_size = int(model_channels * repo_hidden_ratio) + self.norm = operations.LayerNorm(model_channels, dtype=dtype, device=device) + self.gate_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device) + self.content_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device) + self.act = nn.SiLU() + self.final_map = operations.Linear(repo_hidden_size, 3 * num_heads, bias=False, dtype=dtype, device=device) + self.dim_0 = 2 * (head_dim // 6) + self.dim_1 = 2 * (head_dim // 6) + self.dim_2 = head_dim - self.dim_0 - self.dim_1 + dims = [self.dim_0, self.dim_1, self.dim_2] + freqs_list = [] + for d in dims: + freq_dim = d // 2 + freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32)) + self.freqs_0 = nn.Parameter(freqs_list[0]) + self.freqs_1 = nn.Parameter(freqs_list[1]) + self.freqs_2 = nn.Parameter(freqs_list[2]) + + def forward(self, hidden_states): + h = self.norm(hidden_states) + feat = self.act(self.gate_map(h)) * self.content_map(h) + out = self.final_map(feat) + B, L, _ = out.shape + delta_pos = out.reshape(B, L, self.num_heads, 3) + f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device) + f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device) + f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device) + ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi + ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi + ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi + ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2) + cos, sin = ang.cos(), ang.sin() + return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2) + + +class PcdAbsolutePositionEmbedder(nn.Module): + # Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat: + # "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders). + def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.max_res = max_res + self.schedule = schedule + self.freq_dim = channels // in_channels // 2 + + def _freqs(self, device): + if self.schedule == "pow2": + freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device) + res_dim = max(0, self.freq_dim - self.max_res) + freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res + if res_dim > 0 else torch.empty(0, device=device)) + freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim] + return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below + logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device) + return torch.pow(2.0, logs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + x = x.float() + *dims, D = x.shape + out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi + out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1) + if out.shape[-1] < self.channels: + out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1], + device=out.device, dtype=out.dtype)], dim=-1) + return out.to(orig_dtype) + + +def attention(q, k, v, transformer_options=None): + # q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention. + out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2], + skip_reshape=True, skip_output_reshape=True, low_precision_attention=False, + transformer_options=transformer_options) + return out.transpose(1, 2) + + +# Transformer building blocks + +class MLP(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.mlp = nn.Sequential( + operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device), + ) + + def forward(self, x): + return self.mlp(x) + + +class RopeMultiHeadAttention(nn.Module): + def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False, + dtype=None, device=None, operations=None): + super().__init__() + self.channels = channels + self.num_heads = num_heads + self.head_dim = channels // num_heads + self.qk_rms_norm = qk_rms_norm + self.use_rope = use_rope + self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device) + if self.qk_rms_norm: + self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device) + self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device) + self.out = operations.Linear(channels, channels, dtype=dtype, device=device) + + def forward(self, x, rope_emb=None, transformer_options=None): + B, L, C = x.shape + qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + if self.use_rope: + q, k = apply_rope(q, k, rope_emb) + if self.qk_rms_norm: + q = self.q_norm(q) + k = self.k_norm(k) + h = attention(q, k, v, transformer_options) # (B, L, heads, dim) + return self.out(h.reshape(B, L, C)) + + +class UnifiedTransformerBlock(nn.Module): + def __init__(self, channels, num_heads, mlp_ratio=4.0, + use_rope=False, qk_rms_norm=False, qkv_bias=True, + modulation=True, share_mod=False, + dtype=None, device=None, operations=None): + super().__init__() + self.modulation = modulation + self.share_mod = share_mod + self.norm1 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device) + self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads, + qkv_bias=qkv_bias, use_rope=use_rope, qk_rms_norm=qk_rms_norm, + dtype=dtype, device=device, operations=operations) + self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations) + if modulation: + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)) + self.shift_table = nn.Parameter(torch.empty(1, 6 * channels, dtype=dtype, device=device)) + + def forward(self, x, mod=None, rotary_emb=None, transformer_options=None): + if self.modulation: + if not self.share_mod: + mod = self.adaLN_modulation(mod) + mod = mod + comfy.model_management.cast_to(self.shift_table, mod.dtype, mod.device) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1)) + x = torch.addcmul(x, self.attn(h, rope_emb=rotary_emb, transformer_options=transformer_options), gate_msa.unsqueeze(1)) + h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1)) + x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1)) + else: + x = x + self.attn(self.norm1(x), rope_emb=rotary_emb, transformer_options=transformer_options) + x = x + self.mlp(self.norm2(x)) + return x + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + super().__init__() + self.mlp = nn.Sequential( + operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + emb = self.timestep_embedding(t, self.frequency_embedding_size) + return self.mlp(emb.to(self.mlp[0].weight.dtype)) + + +class LatentSeqMMFlowModel(nn.Module): + def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024, + cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2, + num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128, + mlp_ratio=4, share_mod=True, qk_rms_norm=True, + dtype=None, device=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + self.q_token_length = q_token_length + self.in_channels = in_channels + self.cam_channels = cam_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.cond2_channels = cond2_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_refiner_blocks = num_refiner_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + + factory_kwargs = dict(dtype=dtype, device=device) + op_kwargs = dict(operations=operations, **factory_kwargs) + + self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs) + if share_mod: + self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs)) + + self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs) + self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs) + self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None + + # Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding. + # The embedder is parameter-free and the anchors are fixed, precompute once. + sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length) + pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0)) + self.register_buffer("pos_emb", pos_emb, persistent=False) + + # RePo3DRotaryEmbedding layers for the refiner and main blocks + repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs) + self.noise_repo_layers = nn.ModuleList( + [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)]) + self.context_repo_layers = nn.ModuleList( + [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)]) + self.repo_layers = nn.ModuleList( + [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)]) + + # Refiner blocks + block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs) + self.noise_refiner = nn.ModuleList( + [UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)]) + self.context_refiner = nn.ModuleList( + [UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)]) + + self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs) + + self.blocks = nn.ModuleList( + [UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)]) + + self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs)) + self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs) + self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs) + + def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, t, context, ref_latents, transformer_options, **kwargs) + + def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs): + # x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)]. + # context == feature1. + z, camera = x[0], x[1] + feat1 = context + + h_x = self.input_layer(z) + h_cond = self.cond_embedder(feat1) + if ref_latents is not None and self.cond_embedder2 is not None: + # Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length + # (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context. + feat2 = ref_latents[0].flatten(2).transpose(1, 2) + feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0)) + h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype)) + t_emb = self.t_embedder(t) + t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb + + h_x = h_x + self.pos_emb.to(z) + + for i, block in enumerate(self.noise_refiner): + h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options) + + for i, block in enumerate(self.context_refiner): + h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options) + + cam = camera.to(z) + h_cam = self.cam_refiner(cam) + h = torch.cat([h_x, h_cond, h_cam], dim=1) + + for i, block in enumerate(self.blocks): + h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options) + + h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z) + h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z) + + shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1) + scale = 1 + scale + h_x = torch.addcmul(shift, h_x, scale) + h_cam = torch.addcmul(shift, h_cam, scale) + + return self.out_layer(h_x), self.cam_out_layer(h_cam) diff --git a/comfy/ldm/triposplat/preview.py b/comfy/ldm/triposplat/preview.py new file mode 100644 index 000000000..6a942bb53 --- /dev/null +++ b/comfy/ldm/triposplat/preview.py @@ -0,0 +1,91 @@ +# Live preview for TripoSplat: decode an x0 estimate into a coarse gaussian splat and render it with a perspective orbit camera. +import numpy as np +from PIL import Image + +_C0 = 0.28209479177387814 +_LATENT_TOKENS = 8192 # q_token_length +_LATENT_CH = 16 # in_channels +_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32) # object frame -> viewer Y-up frame + + +def _view_matrix(yaw_deg, pitch_deg): + y, p = np.radians(yaw_deg), np.radians(pitch_deg) + Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32) + Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32) + return Rx @ Ry + + +def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0, + max_px=9, min_opacity=0.0, fov=35.0, dist=2.2): + # Project gaussian centers with a perspective camera and paint each as a filled disk whose screen + # radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer. + # gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius. + + pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T + v = pts @ _view_matrix(yaw, pitch).T + zc = v[:, 2] + dist + keep = zc > 1e-2 + if opacity is not None and min_opacity > 0.0: # culls gaussians with very low opacity + keep = keep & (opacity > min_opacity) + v, zc, scale = v[keep], zc[keep], scale[keep] + col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep] + if v.shape[0] == 0: + return Image.fromarray(np.zeros((size, size, 3), np.uint8)) + f = (size / 2) / np.tan(np.radians(fov) / 2) + cx = size / 2 + f * v[:, 0] / zc + cy = size / 2 + f * v[:, 1] / zc + radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32) + + # Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized. + px, py, pz, pc = [], [], [], [] + for r in range(int(radius.min()), int(radius.max()) + 1): + m = radius == r + if not m.any(): + continue + dy, dx = np.mgrid[-r:r + 1, -r:r + 1] + disk = (dx * dx + dy * dy) <= r * r + ox, oy = dx[disk], dy[disk] + px.append((cx[m, None] + ox).ravel()) + py.append((cy[m, None] + oy).ravel()) + pz.append(np.repeat(zc[m], ox.size)) + pc.append(np.repeat(col[m], ox.size, axis=0)) + px, py = np.concatenate(px), np.concatenate(py) + pz, pc = np.concatenate(pz), np.concatenate(pc) + xi = np.clip(px, 0, size - 1).astype(np.int64) + yi = np.clip(py, 0, size - 1).astype(np.int64) + + # Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest + # splat, then decode the winning index back to its color. + pid = yi * size + xi + q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1) # near = small + key = (q << 32) | np.arange(pid.size, dtype=np.int64) + buf = np.full(size * size, 1 << 62, np.int64) + np.minimum.at(buf, pid, key) + img = np.zeros((size * size, 3), np.uint8) + hit = buf < (1 << 62) + img[hit] = pc[buf[hit] & 0xFFFFFFFF] + return Image.fromarray(img.reshape(size, size, 3)) + + +def _extract_latent(x0): + # x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5); + # the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream. + if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH: + return x0 + flat = x0.reshape(x0.shape[0], -1) + return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH) + + +def decode_x0_to_image(decoder, x0, cfg): + # Decode x0 at a coarse octree level / few gaussians and render a preview image. + latent = _extract_latent(x0) + fsm = decoder.first_stage_model + gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype), + num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0] + xyz = gaussian.get_xyz.float().cpu().numpy() + rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5 + scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1) # per-splat world radius (largest axis) + opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0] + return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0), + size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3), + min_opacity=0.01) diff --git a/comfy/ldm/triposplat/vae.py b/comfy/ldm/triposplat/vae.py new file mode 100644 index 000000000..e5ed9fd36 --- /dev/null +++ b/comfy/ldm/triposplat/vae.py @@ -0,0 +1,382 @@ +# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an +# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns +# a Gaussian. The octree sampler uses the global torch RNG (no generator) like upstream, so seed it for repeatable decodes. +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.model_management +import comfy.ops +from .gaussian import build_gaussian_models +from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention + + +# Quasi-random sampling utilities (pure functions, dtype/device-agnostic) + +PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + + +def radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[i], n) for i in range(dim)] + + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + + +def sample_probs(probs, counts, generator=None): + # Systematic resampling: distribute counts[r] draws across the P bins of row r + batch_shape = counts.shape + R = counts.numel() + P = probs.size(-1) + device = probs.device + probs = probs.reshape(R, P).to(torch.float32).clamp_min(0) + counts = counts.reshape(R).to(device=device, dtype=torch.long) + + row_sums = probs.sum(1, keepdim=True) + probs = torch.where(row_sums == 0, probs.new_tensor(1.0 / P), probs / row_sums.clamp_min(1)) + cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12) + + Nmax = int(counts.max()) + if Nmax == 0: + return counts.new_zeros(*batch_shape, P) + cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1) + grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax) + u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded) + idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1) + weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r] + out = torch.zeros(R, P, dtype=torch.float32, device=device) + out.scatter_add_(1, idx, weight) + return out.to(torch.long).view(*batch_shape, P) + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False, + dtype=None, device=None, operations=None): + super().__init__() + assert channels % num_heads == 0 + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.qk_rms_norm = qk_rms_norm + if self._type == "self": + self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device) + else: + self.to_q = operations.Linear(channels, channels, bias=qkv_bias, dtype=dtype, device=device) + self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, dtype=dtype, device=device) + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device) + self.to_out = operations.Linear(channels, channels, dtype=dtype, device=device) + + def forward(self, x, context=None): + B, L, C = x.shape + if self._type == "self": + q, k, v = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1).unbind(dim=2) + else: + Lkv = context.shape[1] + q = self.to_q(x).reshape(B, L, self.num_heads, -1) + k, v = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1).unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + h = attention(q, k, v) + return self.to_out(h.reshape(B, L, -1)) + + +# Octree probability decoder + +class LevelEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024, + dtype=None, device=None, operations=None): + super().__init__() + self.mlp = nn.Sequential( + operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + + @staticmethod + def level_embedding(t, dim, max_period=1024): + half = dim // 2 + freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] * 2 * torch.pi + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period) + return self.mlp(emb.to(self.mlp[0].weight.dtype)) + + +class ModulatedTransformerCrossOnlyBlock(nn.Module): + def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False, + qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.share_mod = share_mod + self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, + type="cross", qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations) + self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)) + + def forward(self, x, mod, context): + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1)) + x = torch.addcmul(x, self.cross_attn(h, context), gate_msa.unsqueeze(1)) + h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1)) + x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1)) + return x + + +class OctreeProbabilityFixedlenDecoder(nn.Module): + # Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits. + def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16, + num_head_channels=64, mlp_ratio=4.0, share_mod=True, + qk_rms_norm_cross=True, dtype=None, device=None, operations=None): + super().__init__() + self.model_channels = model_channels + self.cond_channels = cond_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.share_mod = share_mod + self.qk_rms_norm_cross = qk_rms_norm_cross + self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device) + self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device)) + if cond_channels is not None: + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossOnlyBlock( + model_channels, ctx_channels=cond_channels, num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross, + share_mod=self.share_mod, dtype=dtype, device=device, operations=operations) + for _ in range(num_blocks) + ]) + self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device) + self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device) + self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2") + + def forward(self, x, l, cond): + d = next(self.parameters()).dtype + B, L, _ = x.shape + h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d) + h = self.input_layer(h) + l_emb = self.l_embedder(l) + if self.share_mod: + l_emb = self.adaLN_modulation(l_emb) + cond = cond.to(d) + for block in self.blocks: + h = block(h, l_emb, cond) + h = F.layer_norm(h.float(), h.shape[-1:]).to(d) + logits = self.out_proj(h) + return {"logits": logits, "probs": torch.softmax(logits, dim=-1)} + + @staticmethod + def sample(model, cond, num_points, level, temperature=1.0, generator=None): + B = cond.shape[0] + device = cond.device + child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]], + dtype=torch.long, device=device) + prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device) + prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device) + prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device) + batch_indices_range = torch.arange(B, device=device).unsqueeze(1) + + for lv in range(1, level + 1): + res_p = 1 << (lv - 1) + res = 1 << lv + parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p + res_tensor = torch.full((B,), res, dtype=torch.long, device=device) + pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature + pred_probs = torch.softmax(pred_logits, dim=-1) + pred_log_probs = torch.log_softmax(pred_logits, dim=-1) + sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2) + pred_log_probs = pred_log_probs.flatten(1, 2) + prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1) + child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2) + mask = sampled > 0 + max_valid = mask.sum(dim=1).max().item() + scatter_indices = mask.cumsum(dim=1) - 1 + valid_scatter_indices = scatter_indices[mask] + valid_batch_indices = batch_indices_range.expand_as(mask)[mask] + next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device) + next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask] + next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device) + next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask] + next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device) + next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask] + prev_coords_int = next_prev_coords_int + prev_counts = next_prev_counts + prev_log_probs = next_prev_log_probs + + res = 1 << level + prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points) + coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1) + rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device) + coords_norm = (coords_int.to(torch.float32) + rand) / res + return {"points": coords_norm, "log_probs": prev_log_probs} + + +# Elastic gaussian decoder + +class TransformerCrossBlock(nn.Module): + def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, + qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True, + dtype=None, device=None, operations=None): + super().__init__() + self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) + self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm, dtype=dtype, device=device, operations=operations) + self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross", + qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations) + self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x, context): + x = x + self.self_attn(self.norm1(x)) + x = x + self.cross_attn(self.norm2(x), context) + x = x + self.mlp(self.norm3(x)) + return x + + +class ElasticGaussianFixedlenDecoder(nn.Module): + # Cross-attention transformer over sampled octree points -> per-point gaussian params. + def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16, + num_head_channels=64, mlp_ratio=4.0, *, representation_config=None, + qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None): + super().__init__() + self.rep_config = representation_config or dict( + lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1), + perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32, + filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1, + scaling_activation="softplus", + ) + self.out_channels = self._calc_layout() + self.model_channels = model_channels + self.cond_channels = cond_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device) + if cond_channels is not None: + self.blocks = nn.ModuleList([ + TransformerCrossBlock(model_channels, ctx_channels=cond_channels, + num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, + qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross, + dtype=dtype, device=device, operations=operations) + for _ in range(num_blocks) + ]) + self.in_proj = operations.Linear(in_channels, model_channels, dtype=dtype, device=device) + self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2") + self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device) + self._build_perturbation() + + def _calc_layout(self): + ng = self.rep_config['num_gaussians'] + self.layout = { + '_xyz': {'shape': (ng, 3), 'size': ng * 3}, + '_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3}, + '_scaling': {'shape': (ng, 3), 'size': ng * 3}, + '_rotation': {'shape': (ng, 4), 'size': ng * 4}, + '_opacity': {'shape': (ng, 1), 'size': ng}, + } + self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng} + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + return start + + def _build_perturbation(self): + ng = self.rep_config['num_gaussians'] + perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float() + perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size']) + self.register_buffer('points_offset_perturbation', perturbation) + base = torch.tensor(self.rep_config['offset_scale']) + self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0)) + + def _get_offset(self, h): + B = h.shape[0] + r = self.layout['_offset_scale']['range'] + _offset_scale = F.softplus( + h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape']) + + comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device)) + + r = self.layout['_xyz']['range'] + offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape']) + offset = offset * self.rep_config['lr']['_xyz'] + if self.rep_config['perturb_offset']: + offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device) + offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size'] + offset = offset * _offset_scale + return offset + + def forward(self, x=None, cond=None): + pcd = x["points"] + d = next(self.parameters()).dtype + B, L, _ = pcd.shape + h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d) + h = self.input_layer(h) + cond = cond.to(d) + for block in self.blocks: + h = block(h, cond) + h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype) + return {"features": self.out_proj(h)} + + +# Combined octree gaussian decoder (comfy first-stage model) + +class OctreeGaussianDecoder(nn.Module): + _MAX_VOXEL_LEVEL = 8 + + def __init__(self, dtype=None, device=None, operations=None): + super().__init__() + if operations is None: + operations = comfy.ops.disable_weight_init + self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=operations) + self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=operations) + + @property + def gaussians_per_point(self) -> int: + return self.gs.rep_config['num_gaussians'] + + def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None): + # level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews. + # generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG. + level = self._MAX_VOXEL_LEVEL if level is None else level + num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point) + points_pred = OctreeProbabilityFixedlenDecoder.sample( + self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator, + ) + pred = self.gs(x=points_pred, cond=latent) + return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item diff --git a/comfy/model_base.py b/comfy/model_base.py index 205178911..3e2d4e930 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -46,6 +46,7 @@ import comfy.ldm.wan.model_animate import comfy.ldm.wan.ar_model import comfy.ldm.wan.model_wandancer import comfy.ldm.hunyuan3d.model +import comfy.ldm.triposplat.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model import comfy.ldm.chroma_radiance.model @@ -1806,6 +1807,24 @@ class Hunyuan3Dv2_1(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class TripoSplat(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) # DINOv3 token sequence -> cross-attention context. + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + ref_latents = kwargs.get("reference_latents", None) # Flux2 VAE image latent -> additive second conditioning. + if ref_latents is not None: + out['ref_latents'] = comfy.conds.CONDList(list(ref_latents)) + latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent + if latent_shapes is not None: + out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + return out + + class HiDream(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index f0db7d388..73354b0d2 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -676,6 +676,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys return dit_config + if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat + return {"image_model": "triposplat"} + if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1 return {"image_model": "hidream_o1"} diff --git a/comfy/sd.py b/comfy/sd.py index 30b877b85..9a2d31930 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae +import comfy.ldm.triposplat.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae import comfy.ldm.hunyuan_video.vae @@ -894,6 +895,16 @@ class VAE: #Force cast it for --disable-dynamic-vram users until there is a true core fix. if not comfy.memory_management.aimdo_enabled: self.disable_offload = True + elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder + self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder() + self.latent_channels = 16 + self.latent_dim = 1 + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + # The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian + # decoder directly (structured GaussianSplat objects, not a tensor and reserves VRAM itself from num_gaussians. + def _no_generic_io(*args, **kwargs): + raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)") + self.memory_used_encode = self.memory_used_decode = _no_generic_io else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 00941da53..0872b0e27 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1538,6 +1538,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini +class TripoSplat(supported_models_base.BASE): + # Image -> 3D gaussian splat flow denoiser + unet_config = { + "image_model": "triposplat", + } + + unet_extra_config = {} + + sampling_settings = { + "shift": 3.0, + } + + memory_usage_factor = 0.6 + + latent_format = latent_formats.TripoSplat + + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.TripoSplat(self, device=device) + + def clip_target(self, state_dict={}): + return None + class HiDream(supported_models_base.BASE): unet_config = { "image_model": "hidream", @@ -2200,6 +2224,7 @@ models = [ Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, + TripoSplat, HiDream, HiDreamO1, Chroma, diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py index 7fb878b8b..2ba3a3820 100644 --- a/comfy_extras/nodes_gaussian_splat.py +++ b/comfy_extras/nodes_gaussian_splat.py @@ -968,7 +968,8 @@ class RenderSplat(IO.ComfyNode): bg = _hex_to_rgb(background) bg_imgs = None if bg_image is not None: # resize the plate(s) to the render size: (B,H,W,3) - bi = comfy.utils.common_upscale(bg_image.movedim(-1, 1), width, height, "bicubic", "disabled") + bi = bg_image[... , :3].movedim(-1, 1) # (B,3,H,W) + bi = comfy.utils.common_upscale(bi, width, height, "bicubic", "disabled") bg_imgs = bi.movedim(1, -1).clamp(0, 1) n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still) orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction diff --git a/comfy_extras/nodes_triposplat.py b/comfy_extras/nodes_triposplat.py new file mode 100644 index 000000000..021b669fd --- /dev/null +++ b/comfy_extras/nodes_triposplat.py @@ -0,0 +1,269 @@ +# TripoSplat nodes: image -> 3D gaussian splat + +import logging + +import torch +import torch.nn.functional as F +from typing_extensions import override + +import comfy.model_management +import comfy.nested_tensor +import comfy.patcher_extension +import comfy.utils +from comfy_api.latest import ComfyExtension, IO, Types + + +_Q_TOKEN_LENGTH = 8192 +_LATENT_CHANNELS = 16 +_CAM_CHANNELS = 5 +_DINOV3_MEAN = [0.485, 0.456, 0.406] +_DINOV3_STD = [0.229, 0.224, 0.225] +_NUM_GAUSSIANS_MIN = 32768 +_NUM_GAUSSIANS_MAX = 1048576 + + +def _preprocess(image: torch.Tensor, mask: torch.Tensor, erode_radius: int, size: int) -> torch.Tensor: + # Match original preprocessing: + # resize min side to `size` -> erode alpha -> alpha bbox -> 1.2x square crop -> resize -> composite on black. + rgb = image[..., :3].clamp(0, 1).movedim(-1, 0) # (3, H, W) + alpha = mask.clamp(0, 1)[None] # (1, H, W) + rgba = torch.cat([rgb, alpha], 0)[None] # (1, 4, H, W) + + h, w = rgba.shape[-2:] + s = size / min(w, h) + rgba = comfy.utils.common_upscale(rgba, max(1, round(w * s)), max(1, round(h * s)), "lanczos", "disabled").clamp(0, 1) + + a = rgba[:, 3:4] + if erode_radius > 0: + # min filter over a (2r+1) window == morphological erosion of the alpha matte. + a = -F.max_pool2d(-a, 2 * erode_radius + 1, stride=1, padding=erode_radius) + rgba = torch.cat([rgba[:, :3], a], 1) + + ys, xs = torch.nonzero(a[0, 0] > 0, as_tuple=True) + if xs.numel() == 0: + raise ValueError("TripoSplatPreprocessImage: mask is empty (no foreground pixels).") + x0, x1 = int(xs.min()), int(xs.max()) + y0, y1 = int(ys.min()), int(ys.max()) + cx, cy = (x0 + x1) / 2, (y0 + y1) / 2 + half = max(x1 - x0, y1 - y0) / 2 * 1.2 + left, upper, right, lower = int(cx - half), int(cy - half), int(cx + half), int(cy + half) + + H, W = rgba.shape[-2:] + crop = rgba.new_zeros((1, 4, lower - upper, right - left)) # out-of-bounds stays 0, matching PIL.crop + sx0, sy0, sx1, sy1 = max(left, 0), max(upper, 0), min(right, W), min(lower, H) + if sx1 > sx0 and sy1 > sy0: + crop[:, :, sy0 - upper:sy1 - upper, sx0 - left:sx1 - left] = rgba[:, :, sy0:sy1, sx0:sx1] + + crop = comfy.utils.common_upscale(crop, size, size, "lanczos", "disabled").clamp(0, 1) + out = (crop[:, :3] * crop[:, 3:4])[0].movedim(0, -1) # composite over black == rgb * alpha + return out.unsqueeze(0) # (1, 1024, 1024, 3) + + +class TripoSplatPreprocessImage(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoSplatPreprocessImage", + display_name="TripoSplat Preprocess Image", + category="3d/conditioning", + description="Crop center each image to a square canvas on a black background and add padding.", + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask"), + IO.Int.Input("erode_radius", default=1, min=0, max=16, + tooltip="Erode the alpha matte by this pixel radius before cropping (avoids border bleed)."), + IO.Int.Input("size", default=1024, min=256, max=4096, step=16, + tooltip="Square image size. The model is trained at 1024; other sizes run but are off-distribution."), + ], + outputs=[IO.Image.Output(display_name="image")], + ) + + @classmethod + def execute(cls, image, mask, erode_radius, size) -> IO.NodeOutput: + size = max(16, (int(size) // 16) * 16) # DINOv3 patch / Flux2 VAE stride is 16 + if mask.shape[0] != image.shape[0]: + mask = comfy.utils.repeat_to_batch_size(mask, image.shape[0]) + if tuple(mask.shape[1:]) != tuple(image.shape[1:3]): + mask = F.interpolate(mask[:, None].float(), size=tuple(image.shape[1:3]), mode="bilinear", align_corners=False)[:, 0] + prepared = torch.cat([_preprocess(image[i], mask[i], erode_radius, size) for i in range(image.shape[0])], dim=0) + return IO.NodeOutput(prepared) + + +class TripoSplatConditioning(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoSplatConditioning", + display_name="TripoSplat Conditioning", + category="3d/conditioning", + description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative " + "conditioning, and create the fixed size noise target (latent + camera) for the KSampler", + inputs=[ + IO.ClipVision.Input("clip_vision", tooltip="DINOv3 ViT-H/16+ image encoder"), + IO.Vae.Input("vae", tooltip="Flux2 VAE"), + IO.Image.Input("image"), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + IO.Latent.Output(display_name="latent", tooltip="The fixed size noise target (latent +camera)."), + ], + ) + + @classmethod + def execute(cls, clip_vision, vae, image) -> IO.NodeOutput: + # feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top + comfy.model_management.load_model_gpu(clip_vision.patcher) + device = clip_vision.load_device + model_dtype = next(clip_vision.model.parameters()).dtype + img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1] + mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1) + std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1) + img = (img - mean) / std + seq = clip_vision.model(pixel_values=img.to(model_dtype))[0] + feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device()) + + # Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry + ref = vae.encode(image).to(comfy.model_management.intermediate_device()) # (B, 128, H, W) + b = ref.shape[0] + + positive = [[feature1, {"reference_latents": [ref]}]] + negative = [[torch.zeros_like(feature1), {"reference_latents": [torch.zeros_like(ref)]}]] + + # Fixed noise target: the latent is a constant-shape (8192, 16) shape-code + a (1, 5) camera token + dev = comfy.model_management.intermediate_device() + latent_seq = torch.zeros([b, _Q_TOKEN_LENGTH, _LATENT_CHANNELS], device=dev) + camera = torch.zeros([b, 1, _CAM_CHANNELS], device=dev) + samples = comfy.nested_tensor.NestedTensor((latent_seq, camera)) + return IO.NodeOutput(positive, negative, {"samples": samples}) + + +class VAEDecodeTripoSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeTripoSplat", + display_name="TripoSplat Decode", + category="3d/latent", + description="Decode the sampled TripoSplat latent into a 3D gaussian splat. " + "Modify the number of gaussians to vary the density.", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"), + IO.Int.Input("num_gaussians", default=262144, min=_NUM_GAUSSIANS_MIN, max=_NUM_GAUSSIANS_MAX, step=32, + tooltip="Number of gaussians to produce (rounded to a multiple of 32). " + "262144 matches the octree's point density; higher oversamples the same points " + "(denser, but no new detail) and costs proportionally more VRAM/time."), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, + tooltip="Seeds the octree point sampler (global RNG) for deterministic decodes."), + ], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, samples, vae, num_gaussians, seed) -> IO.NodeOutput: + s = samples["samples"] + latent = s.unbind()[0] if getattr(s, "is_nested", False) else s # take the latent stream, drop camera + + decoder = vae.first_stage_model + gpp = decoder.gaussians_per_point + n = max(_NUM_GAUSSIANS_MIN, min(_NUM_GAUSSIANS_MAX, int(num_gaussians))) + if n % gpp != 0: + n = round(n / gpp) * gpp + + dtype_size = comfy.model_management.dtype_size(vae.vae_dtype) + hidden = decoder.gs.model_channels + cond_tokens = latent.shape[1] + memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size + comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required) + latent = latent.to(device=vae.device, dtype=vae.vae_dtype) + generator = torch.Generator(device="cpu").manual_seed(seed) + parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)] + positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts)) + return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh)) + + +class TripoSplatSamplingPreview(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoSplatSamplingPreview", + display_name="TripoSplat Sampling Preview", + category="3d/latent", + description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded " + "gaussian splat preview at each step.", + inputs=[ + IO.Model.Input("model"), + IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"), + IO.Int.Input("octree_level", default=5, min=2, max=8, advanced=True, + tooltip="Octree depth for the preview decode (lower = cheaper/coarser)."), + IO.Int.Input("num_gaussians", default=16384, min=1024, max=262144, step=32, + tooltip="Number of gaussians to produce for the preview (rounded to a multiple of 32)."), + IO.Float.Input("yaw", default=90.0, min=-360.0, max=360.0, step=1.0, tooltip="Preview camera yaw in degrees.", advanced=True,), + IO.Float.Input("pitch", default=15.0, min=-89.0, max=89.0, step=1.0, tooltip="Preview camera pitch in degrees.", advanced=True,), + IO.Int.Input("point_size", default=3, min=1, max=16, + tooltip="Maximum splat radius in pixels. Each gaussian is sized from its scale and capped here; " + "lower = finer/pointier, higher = chunkier."), + ], + outputs=[IO.Model.Output()], + ) + + @classmethod + def execute(cls, model, vae, octree_level, num_gaussians, yaw, pitch, point_size) -> IO.NodeOutput: + from comfy.ldm.triposplat.preview import decode_x0_to_image + cfg = {"gaussians": num_gaussians, "level": octree_level, "yaw": yaw, "pitch": pitch, + "point_size": point_size} + + fsm = vae.first_stage_model + cond_tokens = model.model.diffusion_model.q_token_length + memory_required = (cond_tokens * 4 + (num_gaussians // fsm.gaussians_per_point) * 10) * fsm.gs.model_channels * comfy.model_management.dtype_size(vae.vae_dtype) + + # Live preview via WrappersMP.OUTER_SAMPLE + ProgressBar + # The wrapper augments the sampler's own callback to decode x0 -> gaussian splat -> preview image each step + def outer_sample_wrapper(executor, *args, **kwargs): + args = list(args) + cb_idx = 5 # outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + orig_cb = args[cb_idx] if len(args) > cb_idx else kwargs.get("callback") + state = {"ok": True, "pbar": None, "loaded": False} + + def callback(step, x0, x, total_steps): + if orig_cb is not None: + orig_cb(step, x0, x, total_steps) + if not state["ok"]: + return + try: + if not state["loaded"]: + comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required) + state["loaded"] = True + img = decode_x0_to_image(vae, x0, cfg) + if state["pbar"] is None: + state["pbar"] = comfy.utils.ProgressBar(total_steps) + state["pbar"].update_absolute(step + 1, total_steps, ("JPEG", img, 512)) + except Exception as e: + logging.warning("TripoSplatSamplingPreview: preview failed, disabling ({})".format(e)) + state["ok"] = False + + if len(args) > cb_idx: + args[cb_idx] = callback + else: + kwargs["callback"] = callback + return executor(*args, **kwargs) + + m = model.clone() + m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "triposplat_sampling_preview", outer_sample_wrapper) + return IO.NodeOutput(m) + + +class TripoSplatExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TripoSplatPreprocessImage, + TripoSplatConditioning, + VAEDecodeTripoSplat, + TripoSplatSamplingPreview, + ] + + +async def comfy_entrypoint() -> TripoSplatExtension: + return TripoSplatExtension() diff --git a/nodes.py b/nodes.py index 5678bc22d..331425b87 100644 --- a/nodes.py +++ b/nodes.py @@ -2456,6 +2456,7 @@ async def init_builtin_extra_nodes(): "nodes_moge.py", "nodes_mediapipe.py", "nodes_gaussian_splat.py", + "nodes_triposplat.py" ] import_failed = [] From af58c5e6744590f06fdf0eacdc9068cd6e8b1c8b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:12:12 +0300 Subject: [PATCH 19/32] [Partner Nodes] feat: add Flux Virtual Try-On and Erase nodes (#14207) --- comfy_api_nodes/apis/bfl.py | 178 ++++++++++++++++------------------- comfy_api_nodes/nodes_bfl.py | 172 +++++++++++++++++++++++++++++++-- 2 files changed, 243 insertions(+), 107 deletions(-) diff --git a/comfy_api_nodes/apis/bfl.py b/comfy_api_nodes/apis/bfl.py index f0665fa09..2ad651122 100644 --- a/comfy_api_nodes/apis/bfl.py +++ b/comfy_api_nodes/apis/bfl.py @@ -1,71 +1,71 @@ from enum import Enum -from typing import Any, Dict, Optional +from typing import Any -from pydantic import BaseModel, Field, confloat, conint - - -class BFLOutputFormat(str, Enum): - png = 'png' - jpeg = 'jpeg' +from pydantic import BaseModel, Field class BFLFluxExpandImageRequest(BaseModel): - prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image') - bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image') - left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image') - right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand') + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + top: int = Field(...) + bottom: int = Field(...) + left: int = Field(...) + right: int = Field(...) + steps: int = Field(...) + guidance: float = Field(...) + safety_tolerance: int = Field(6) + output_format: str = Field("png") + image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand") class BFLFluxFillImageRequest(BaseModel): - prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + steps: int = Field(...) + guidance: float = Field(...) + safety_tolerance: int = Field(6) + output_format: str = Field("png") + image: str = Field( + None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.", ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' + mask: str = Field( + None, description="Base64-encoded string representing the mask of the areas you wish to modify." ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] + + +class BFLFluxEraseRequest(BaseModel): + image: str = Field(..., description="A Base64-encoded string representing the image to erase from.") + mask: str = Field( + ..., + description="A Base64-encoded black/white mask matching the input dimensions; " + "white (255) marks areas to remove, black (0) marks areas to preserve.", ) - image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.') - mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.') + dilate_pixels: int = Field(10) + output_format: str = Field("png") + + +class BFLFluxVTORequest(BaseModel): + prompt: str = Field( + ..., description="Natural-language styling instruction. Required field, but may be an empty string." + ) + person: str = Field(..., description="A Base64-encoded string representing the person image.") + garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.") + seed: int | None = Field(None) + safety_tolerance: int = Field(5) + output_format: str = Field("png") class BFLFluxProGenerateRequest(BaseModel): - prompt: str = Field(..., description='The text prompt for image generation.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.') - height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format') - # image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field( - # None, description='Blend between the prompt and the image prompt.' - # ) + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + width: int = Field(1024, description="Must be a multiple of 32.") + height: int = Field(768, description="Must be a multiple of 32.") + safety_tolerance: int = Field(6) + output_format: str = Field("png") + image_prompt: str | None = Field(None, description="Optional image to remix in base64 format") class Flux2ProGenerateRequest(BaseModel): @@ -83,55 +83,37 @@ class Flux2ProGenerateRequest(BaseModel): input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation") - safety_tolerance: int | None = Field( - 5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5 - ) - output_format: str | None = Field( - "png", description="Output format for the generated image. Can be 'jpeg' or 'png'." - ) + safety_tolerance: int = Field(5) + output_format: str = Field("png") class BFLFluxKontextProGenerateRequest(BaseModel): - prompt: str = Field(..., description='The text prompt for what you wannt to edit.') - input_image: Optional[str] = Field(None, description='Image to edit in base64 format') - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process') - steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=2)] = Field( - 2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) + prompt: str = Field(...) + input_image: str | None = Field(None, description="Image to edit in base64 format") + seed: int | None = Field(None) + guidance: float = Field(...) + steps: int = Field(...) + safety_tolerance: int = Field(2) + output_format: str = Field("png") + aspect_ratio: str | None = Field(None) + prompt_upsampling: bool | None = Field(None) class BFLFluxProUltraGenerateRequest(BaseModel): - prompt: str = Field(..., description='The text prompt for image generation.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.') - image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format') - image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field( - None, description='Blend between the prompt and the image prompt.' - ) + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + aspect_ratio: str | None = Field(None) + safety_tolerance: int = Field(6) + output_format: str = Field("png") + raw: bool | None = Field(None) + image_prompt: str | None = Field(None, description="Optional image to remix in base64 format") + image_prompt_strength: float | None = Field(None) class BFLFluxProGenerateResponse(BaseModel): - id: str = Field(..., description="The unique identifier for the generation task.") - polling_url: str = Field(..., description="URL to poll for the generation result.") + id: str = Field(...) + polling_url: str = Field(...) cost: float | None = Field(None, description="Price in cents") @@ -145,7 +127,7 @@ class BFLStatus(str, Enum): class BFLFluxStatusResponse(BaseModel): - id: str = Field(..., description="The unique identifier for the generation task.") - status: BFLStatus = Field(..., description="The status of the task.") - result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).") - progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0) + id: str = Field(...) + status: BFLStatus = Field(...) + result: dict[str, Any] | None = Field(None) + progress: float | None = Field(None, ge=0.0, le=1.0) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index f1a5dc5f0..996ab0a27 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -4,17 +4,20 @@ from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bfl import ( + BFLFluxEraseRequest, BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, BFLFluxProGenerateResponse, BFLFluxProUltraGenerateRequest, BFLFluxStatusResponse, + BFLFluxVTORequest, BFLStatus, Flux2ProGenerateRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, + convert_mask_to_image, download_url_to_image_tensor, get_number_of_images, poll_op, @@ -22,19 +25,11 @@ from comfy_api_nodes.util import ( sync_op, tensor_to_base64_string, validate_aspect_ratio_string, + validate_image_dimensions, validate_string, ) -def convert_mask_to_image(mask: Input.Image): - """ - Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. - """ - mask = mask.unsqueeze(-1) - mask = torch.cat([mask] * 3, dim=-1) - return mask - - class FluxProUltraImageNode(IO.ComfyNode): @classmethod @@ -519,6 +514,163 @@ class FluxProFillNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class FluxEraseNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="FluxEraseNode", + display_name="Flux Erase Image", + category="image/partner/BFL", + description="Removes the masked object from an image and reconstructs the background. " + "Paint the mask over what you want to erase.", + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."), + IO.Int.Input( + "dilate_pixels", + default=10, + min=0, + max=25, + tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + mask: Input.Image, + dilate_pixels: int = 10, + ) -> IO.NodeOutput: + validate_image_dimensions(image, min_width=256, min_height=256) + mask = resize_mask_to_image(mask, image) + mask = tensor_to_base64_string(convert_mask_to_image(mask)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxEraseRequest( + image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed + mask=mask, + dilate_pixels=dilate_pixels, + ), + ) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + + +class FluxVTONode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="FluxVTONode", + display_name="Flux Virtual Try-On", + category="image/partner/BFL", + description="Virtual try-on: dresses the person in the provided garment.", + inputs=[ + IO.Image.Input("person", tooltip="Image of the person to dress."), + IO.Image.Input("garment", tooltip="Image of the garment to apply."), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""", + ), + ) + + @classmethod + async def execute( + cls, + person: Input.Image, + garment: Input.Image, + prompt: str = "", + seed: int = 0, + ) -> IO.NodeOutput: + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxVTORequest( + prompt=prompt, + person=tensor_to_base64_string(person[:, :, :, :3]), + garment=tensor_to_base64_string(garment[:, :, :, :3]), + seed=seed, + ), + ) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + + class Flux2ProImageNode(IO.ComfyNode): NODE_ID = "Flux2ProImageNode" @@ -853,6 +1005,8 @@ class BFLExtension(ComfyExtension): FluxKontextMaxImageNode, FluxProExpandNode, FluxProFillNode, + FluxEraseNode, + FluxVTONode, Flux2ProImageNode, Flux2MaxImageNode, Flux2ImageNode, From 412d9ac33a9ba4772e130acb40f0c4b810d9d773 Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Mon, 1 Jun 2026 22:41:00 +0800 Subject: [PATCH 20/32] chore: update workflow templates to v0.9.92 (#14212) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 14bba1437..b09d31a8b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.44.19 -comfyui-workflow-templates==0.9.91 +comfyui-workflow-templates==0.9.92 comfyui-embedded-docs==0.5.2 torch torchsde From 0b610bd63a3ea2be55e52691881e57b59198c3c2 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Mon, 1 Jun 2026 19:09:57 +0300 Subject: [PATCH 21/32] [Partner Nodes] fix: respect VideoSlice trim when resizing videos (#14213) --- comfy_api/latest/_input/video_types.py | 6 ++++ comfy_api/latest/_input_impl/video_types.py | 6 ++++ comfy_api_nodes/util/conversions.py | 34 +++++++++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index 451e9526e..8fff52c16 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -65,6 +65,12 @@ class VideoInput(ABC): buffer.seek(0) return buffer + def get_active_trim_window(self) -> tuple[float, float]: + """Return the active trim as ``(start_time, duration)`` in seconds (start_time normalized + to ``>= 0``; ``duration == 0`` means "until the end"). Default: no trim; trimmable subclasses override. + """ + return 0.0, 0.0 + # Provide a default implementation, but subclasses can provide optimized versions # if possible. def get_dimensions(self) -> tuple[int, int]: diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 99e67d363..4a12ff9c1 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -75,6 +75,12 @@ class VideoFromFile(VideoInput): self.__file.seek(0) return self.__file + def get_active_trim_window(self) -> tuple[float, float]: + start_time = self.__start_time + if start_time < 0: + start_time = max(self._get_raw_duration() + start_time, 0.0) + return float(start_time), float(self.__duration) + def get_dimensions(self) -> tuple[int, int]: """ Returns the dimensions of the video input. diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 5738df57f..a1b5d599c 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -469,6 +469,11 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input input_container = None output_container = None + # get_stream_source() is untrimmed, so apply the trim window in this same pass. + # start_time is normalized (>= 0); duration == 0 means "until the end". + start_time, duration = video.get_active_trim_window() + trimming = bool(start_time or duration) + try: input_source = video.get_stream_source() input_container = av.open(input_source, mode="r") @@ -487,16 +492,45 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input audio_stream.layout = stream.layout break + in_video = input_container.streams.video[0] + start_pts = int(start_time / in_video.time_base) if trimming else 0 + end_pts = int((start_time + duration) / in_video.time_base) if duration else None + if start_pts: + input_container.seek(start_pts, stream=in_video) + + encoded = 0 for frame in input_container.decode(video=0): + if trimming: + if frame.pts is None or frame.pts < start_pts: + continue + if end_pts is not None and frame.pts >= end_pts: + break frame = frame.reformat(width=out_w, height=out_h, format="yuv420p") + # Re-wrap as a fresh frame: dropping irregular source timestamps (VFR/AVI/GIF/...) + # lets the encoder assign clean ones and avoids mp4 muxer errors. + frame = av.VideoFrame.from_ndarray(frame.to_ndarray(format="yuv420p"), format="yuv420p") for packet in video_stream.encode(frame): output_container.mux(packet) + encoded += 1 for packet in video_stream.encode(): output_container.mux(packet) + if encoded == 0: + raise ValueError( + f"resize produced no frames (start_time={start_time}, duration={duration} " + "selected nothing from the source)" + ) + if audio_stream is not None: input_container.seek(0) for audio_frame in input_container.decode(audio=0): + if trimming: + if audio_frame.time is None or audio_frame.time < start_time: + continue + if duration and audio_frame.time > start_time + duration: + break + # Carry odd audio time bases the mp4 muxer rejects; reset pts, encoder assigns clean ones (MP3-in-AVI) + audio_frame.pts = None for packet in audio_stream.encode(audio_frame): output_container.mux(packet) for packet in audio_stream.encode(): From a88e02b18576283b1ff25a4b564548c5dc42cbf6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 1 Jun 2026 13:05:25 -0400 Subject: [PATCH 22/32] ComfyUI v0.23.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 0bb0f780c..19e8f8cfc 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.22.0" +__version__ = "0.23.0" diff --git a/pyproject.toml b/pyproject.toml index 1e449b4a3..e118800e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.22.0" +version = "0.23.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From e785f0d212731e7f0f4b8c1638c58ab7df6f16b7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:35:26 -0700 Subject: [PATCH 23/32] Some cast/dtype fixes for the birefnet and dino3 models. (#14217) --- comfy/background_removal/birefnet.py | 2 +- comfy/clip_vision.py | 5 ----- comfy/image_encoders/dino3.py | 4 +--- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/comfy/background_removal/birefnet.py b/comfy/background_removal/birefnet.py index df54b2b90..78a80246e 100644 --- a/comfy/background_removal/birefnet.py +++ b/comfy/background_removal/birefnet.py @@ -105,7 +105,7 @@ class WindowAttention(nn.Module): relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 337575191..ce8924a11 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -2,7 +2,6 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl import os import json import logging -import torch import comfy.ops import comfy.model_patcher @@ -50,10 +49,6 @@ class ClipVisionModel(): self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) - if self.model_type == "dinov3" and self.dtype == torch.float16: - # DINOv3's activations borderline fits fp16, preferring bf16 if available for better stability #TODO: further fp16 tests in practice - if comfy.model_management.should_use_bf16(self.load_device, prioritize_performance=True): - self.dtype = torch.bfloat16 self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 9bd42a66b..014d1d29a 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -166,9 +166,8 @@ class DINOv3ViTEmbeddings(nn.Module): def forward(self, pixel_values, bool_masked_pos=None): batch_size = pixel_values.shape[0] - target_dtype = self.patch_embeddings.weight.dtype - patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + patch_embeddings = self.patch_embeddings(pixel_values) patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) if bool_masked_pos is not None: @@ -244,7 +243,6 @@ class DINOv3ViTModel(nn.Module): return self.embeddings.patch_embeddings def forward(self, pixel_values, bool_masked_pos=None, **kwargs): - pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype) hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) position_embeddings = self.rope_embeddings(pixel_values) From 06b710aa685947f3be69da1c95216e63433f5cd1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:35:52 -0700 Subject: [PATCH 24/32] Fix issue with triposplat preview and old offloading mode. (#14218) --- comfy_extras/nodes_triposplat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_triposplat.py b/comfy_extras/nodes_triposplat.py index 021b669fd..5646d611b 100644 --- a/comfy_extras/nodes_triposplat.py +++ b/comfy_extras/nodes_triposplat.py @@ -233,7 +233,9 @@ class TripoSplatSamplingPreview(IO.ComfyNode): return try: if not state["loaded"]: - comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required) + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + loaded_models.append(vae.patcher) + comfy.model_management.load_models_gpu(loaded_models, memory_required=memory_required) state["loaded"] = True img = decode_x0_to_image(vae, x0, cfg) if state["pbar"] is None: From 4b48535a7d66b89a4314e087e70fb7051e54eaaa Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 1 Jun 2026 18:08:20 -0700 Subject: [PATCH 25/32] Do tripo dinov3 inference in fp32. (#14221) --- comfy/image_encoders/dino3.py | 7 ++++--- comfy_extras/nodes_triposplat.py | 3 +-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 014d1d29a..ad29b06f8 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import comfy.ops from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale @@ -171,11 +172,11 @@ class DINOv3ViTEmbeddings(nn.Module): patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) if bool_masked_pos is not None: - mask_token = self.mask_token.to(patch_embeddings.dtype) + mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings) patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) - cls_token = self.cls_token.expand(batch_size, -1, -1).to(patch_embeddings.device) - register_tokens = self.register_tokens.expand(batch_size, -1, -1).to(patch_embeddings.device) + cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings) + register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings) embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) return embeddings diff --git a/comfy_extras/nodes_triposplat.py b/comfy_extras/nodes_triposplat.py index 5646d611b..1848ad31a 100644 --- a/comfy_extras/nodes_triposplat.py +++ b/comfy_extras/nodes_triposplat.py @@ -115,12 +115,11 @@ class TripoSplatConditioning(IO.ComfyNode): # feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top comfy.model_management.load_model_gpu(clip_vision.patcher) device = clip_vision.load_device - model_dtype = next(clip_vision.model.parameters()).dtype img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1] mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1) std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1) img = (img - mean) / std - seq = clip_vision.model(pixel_values=img.to(model_dtype))[0] + seq = clip_vision.model(pixel_values=img.float())[0] feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device()) # Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry From 33799c4a2ee286b5b6b8aac3c45c43245641fb47 Mon Sep 17 00:00:00 2001 From: vidigoat Date: Tue, 2 Jun 2026 06:45:04 +0530 Subject: [PATCH 26/32] Fix uncaught OverflowError in Math Expression node for large int results (#14214) --- comfy_extras/nodes_math.py | 11 +++++++++-- tests-unit/comfy_extras_test/nodes_math_test.py | 7 +++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py index 873ee7b51..0883c65ac 100644 --- a/comfy_extras/nodes_math.py +++ b/comfy_extras/nodes_math.py @@ -102,11 +102,18 @@ class MathExpressionNode(io.ComfyNode): f"Math Expression '{expression}' must evaluate to a numeric result, " f"got {type(result).__name__}: {result!r}" ) - if not math.isfinite(result): + try: + float_result = float(result) + except OverflowError: + raise ValueError( + f"Math Expression '{expression}' produced a result too large to " + f"represent as a float: {result}" + ) from None + if not math.isfinite(float_result): raise ValueError( f"Math Expression '{expression}' produced a non-finite result: {result}" ) - return io.NodeOutput(float(result), int(result), bool(result)) + return io.NodeOutput(float_result, int(result), bool(result)) class MathExtension(ComfyExtension): diff --git a/tests-unit/comfy_extras_test/nodes_math_test.py b/tests-unit/comfy_extras_test/nodes_math_test.py index 714e37c32..030accc5e 100644 --- a/tests-unit/comfy_extras_test/nodes_math_test.py +++ b/tests-unit/comfy_extras_test/nodes_math_test.py @@ -197,3 +197,10 @@ class TestMathExpressionExecute: def test_pow_huge_exponent_raises(self): with pytest.raises(ValueError, match="Exponent .* exceeds maximum"): self._exec("pow(a, b)", a=10, b=10000000) + + def test_huge_int_result_raises_value_error(self): + # Exponent is within the allowed MAX_EXPONENT range, so the result is a + # finite Python int that is nonetheless too large to convert to float. + # This must raise a clean ValueError, not an uncaught OverflowError. + with pytest.raises(ValueError, match="too large to represent as a float"): + self._exec("2 ** 3999") From e88a81d316205bf3c3f9293778b78f5949e82e97 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Mon, 1 Jun 2026 21:24:46 -0700 Subject: [PATCH 27/32] Revert partner nodes category (#14229) --- comfy_api_nodes/nodes_anthropic.py | 2 +- comfy_api_nodes/nodes_beeble.py | 4 +-- comfy_api_nodes/nodes_bfl.py | 16 ++++----- comfy_api_nodes/nodes_bria.py | 6 ++-- comfy_api_nodes/nodes_bytedance.py | 24 ++++++------- comfy_api_nodes/nodes_bytedance_llm.py | 2 +- comfy_api_nodes/nodes_elevenlabs.py | 16 ++++----- comfy_api_nodes/nodes_gemini.py | 12 +++---- comfy_api_nodes/nodes_grok.py | 14 ++++---- comfy_api_nodes/nodes_hitpaw.py | 4 +-- comfy_api_nodes/nodes_hunyuan3d.py | 12 +++---- comfy_api_nodes/nodes_ideogram.py | 6 ++-- comfy_api_nodes/nodes_kling.py | 50 +++++++++++++------------- comfy_api_nodes/nodes_krea.py | 4 +-- comfy_api_nodes/nodes_ltxv.py | 4 +-- comfy_api_nodes/nodes_luma.py | 16 ++++----- comfy_api_nodes/nodes_magnific.py | 10 +++--- comfy_api_nodes/nodes_meshy.py | 14 ++++---- comfy_api_nodes/nodes_minimax.py | 8 ++--- comfy_api_nodes/nodes_openai.py | 14 ++++---- comfy_api_nodes/nodes_openrouter.py | 2 +- comfy_api_nodes/nodes_pixverse.py | 8 ++--- comfy_api_nodes/nodes_quiver.py | 4 +-- comfy_api_nodes/nodes_recraft.py | 38 ++++++++++---------- comfy_api_nodes/nodes_reve.py | 6 ++-- comfy_api_nodes/nodes_rodin.py | 14 ++++---- comfy_api_nodes/nodes_runway.py | 8 ++--- comfy_api_nodes/nodes_sonilo.py | 4 +-- comfy_api_nodes/nodes_sora.py | 2 +- comfy_api_nodes/nodes_stability.py | 16 ++++----- comfy_api_nodes/nodes_topaz.py | 6 ++-- comfy_api_nodes/nodes_tripo.py | 22 ++++++------ comfy_api_nodes/nodes_veo2.py | 6 ++-- comfy_api_nodes/nodes_vidu.py | 26 +++++++------- comfy_api_nodes/nodes_wan.py | 28 +++++++-------- comfy_api_nodes/nodes_wavespeed.py | 4 +-- 36 files changed, 216 insertions(+), 216 deletions(-) diff --git a/comfy_api_nodes/nodes_anthropic.py b/comfy_api_nodes/nodes_anthropic.py index 7805c96ce..87a870553 100644 --- a/comfy_api_nodes/nodes_anthropic.py +++ b/comfy_api_nodes/nodes_anthropic.py @@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode): return IO.Schema( node_id="ClaudeNode", display_name="Anthropic Claude", - category="text/partner/Anthropic", + category="partner/text/Anthropic", essentials_category="Text Generation", description="Generate text responses with Anthropic's Claude models. " "Provide a text prompt and optionally one or more images for multimodal context.", diff --git a/comfy_api_nodes/nodes_beeble.py b/comfy_api_nodes/nodes_beeble.py index f1082884c..d863c2130 100644 --- a/comfy_api_nodes/nodes_beeble.py +++ b/comfy_api_nodes/nodes_beeble.py @@ -206,7 +206,7 @@ class BeebleSwitchXVideoEdit(IO.ComfyNode): return IO.Schema( node_id="BeebleSwitchXVideoEdit", display_name="Beeble SwitchX Video Edit", - category="video/partner/Beeble", + category="partner/video/Beeble", description=( "Edit a video with Beeble SwitchX. Switches anything in the scene (background, " "lighting, costume) while preserving the original subject's pixels and motion. " @@ -302,7 +302,7 @@ class BeebleSwitchXImageEdit(IO.ComfyNode): return IO.Schema( node_id="BeebleSwitchXImageEdit", display_name="Beeble SwitchX Image Edit", - category="image/partner/Beeble", + category="partner/image/Beeble", description=( "Edit a single image with Beeble SwitchX. Switches anything in the scene " "(background, lighting, costume) while preserving the original subject's pixels. " diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 996ab0a27..79961ff9d 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -37,7 +37,7 @@ class FluxProUltraImageNode(IO.ComfyNode): return IO.Schema( node_id="FluxProUltraImageNode", display_name="Flux 1.1 [pro] Ultra Image", - category="image/partner/BFL", + category="partner/image/BFL", description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.", inputs=[ IO.String.Input( @@ -155,7 +155,7 @@ class FluxKontextProImageNode(IO.ComfyNode): return IO.Schema( node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, - category="image/partner/BFL", + category="partner/image/BFL", description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.", inputs=[ IO.String.Input( @@ -277,7 +277,7 @@ class FluxProExpandNode(IO.ComfyNode): return IO.Schema( node_id="FluxProExpandNode", display_name="Flux.1 Expand Image", - category="image/partner/BFL", + category="partner/image/BFL", description="Outpaints image based on prompt.", inputs=[ IO.Image.Input("image"), @@ -414,7 +414,7 @@ class FluxProFillNode(IO.ComfyNode): return IO.Schema( node_id="FluxProFillNode", display_name="Flux.1 Fill Image", - category="image/partner/BFL", + category="partner/image/BFL", description="Inpaints image based on mask and prompt.", inputs=[ IO.Image.Input("image"), @@ -521,7 +521,7 @@ class FluxEraseNode(IO.ComfyNode): return IO.Schema( node_id="FluxEraseNode", display_name="Flux Erase Image", - category="image/partner/BFL", + category="partner/image/BFL", description="Removes the masked object from an image and reconstructs the background. " "Paint the mask over what you want to erase.", inputs=[ @@ -597,7 +597,7 @@ class FluxVTONode(IO.ComfyNode): return IO.Schema( node_id="FluxVTONode", display_name="Flux Virtual Try-On", - category="image/partner/BFL", + category="partner/image/BFL", description="Virtual try-on: dresses the person in the provided garment.", inputs=[ IO.Image.Input("person", tooltip="Image of the person to dress."), @@ -697,7 +697,7 @@ class Flux2ProImageNode(IO.ComfyNode): return IO.Schema( node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, - category="image/partner/BFL", + category="partner/image/BFL", description="Generates images synchronously based on prompt and resolution.", inputs=[ IO.String.Input( @@ -868,7 +868,7 @@ class Flux2ImageNode(IO.ComfyNode): return IO.Schema( node_id="Flux2ImageNode", display_name="Flux.2 Image", - category="image/partner/BFL", + category="partner/image/BFL", description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index 53e763210..69b0233af 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -31,7 +31,7 @@ class BriaImageEditNode(IO.ComfyNode): return IO.Schema( node_id="BriaImageEditNode", display_name="Bria FIBO Image Edit", - category="image/partner/Bria", + category="partner/image/Bria", description="Edit images using Bria latest model", inputs=[ IO.Combo.Input("model", options=["FIBO"]), @@ -169,7 +169,7 @@ class BriaRemoveImageBackground(IO.ComfyNode): return IO.Schema( node_id="BriaRemoveImageBackground", display_name="Bria Remove Image Background", - category="image/partner/Bria", + category="partner/image/Bria", description="Remove the background from an image using Bria RMBG 2.0.", inputs=[ IO.Image.Input("image"), @@ -245,7 +245,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode): return IO.Schema( node_id="BriaRemoveVideoBackground", display_name="Bria Remove Video Background", - category="video/partner/Bria", + category="partner/video/Bria", description="Remove the background from a video using Bria. ", inputs=[ IO.Video.Input("video"), diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 3711bac1d..d8885a7e5 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -368,7 +368,7 @@ class ByteDanceImageNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageNode", display_name="ByteDance Image", - category="image/partner/ByteDance", + category="partner/image/ByteDance", description="Generate images using ByteDance models via api based on prompt", inputs=[ IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]), @@ -492,7 +492,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedreamNode", display_name="ByteDance Seedream 4.5 & 5.0", - category="image/partner/ByteDance", + category="partner/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ IO.Combo.Input( @@ -754,7 +754,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedreamNodeV2", display_name="ByteDance Seedream 4.5 & 5.0", - category="image/partner/ByteDance", + category="partner/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ IO.String.Input( @@ -920,7 +920,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceTextToVideoNode", display_name="ByteDance Text to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using ByteDance models via api based on prompt", inputs=[ IO.Combo.Input( @@ -1048,7 +1048,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageToVideoNode", display_name="ByteDance Image to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using ByteDance models via api based on image and prompt", inputs=[ IO.Combo.Input( @@ -1185,7 +1185,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceFirstLastFrameNode", display_name="ByteDance First-Last-Frame to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using prompt and first and last frames.", inputs=[ IO.Combo.Input( @@ -1333,7 +1333,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageReferenceNode", display_name="ByteDance Reference Images to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using prompt and reference images.", inputs=[ IO.Combo.Input( @@ -1576,7 +1576,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2TextToVideoNode", display_name="ByteDance Seedance 2.0 Text to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using Seedance 2.0 models based on a text prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1677,7 +1677,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2FirstLastFrameNode", display_name="ByteDance Seedance 2.0 First-Last-Frame to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.", inputs=[ IO.DynamicCombo.Input( @@ -1944,7 +1944,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2ReferenceNode", display_name="ByteDance Seedance 2.0 Reference to Video", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description="Generate, edit, or extend video using Seedance 2.0 with reference images, " "videos, and audio. Supports multimodal reference, video editing, and video extension.", inputs=[ @@ -2241,7 +2241,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode): return IO.Schema( node_id="ByteDanceCreateImageAsset", display_name="ByteDance Create Image Asset", - category="image/partner/ByteDance", + category="partner/image/ByteDance", description=( "Create a Seedance 2.0 personal image asset. Uploads the input image and " "registers it in the given asset group. If group_id is empty, runs a real-person " @@ -2308,7 +2308,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode): return IO.Schema( node_id="ByteDanceCreateVideoAsset", display_name="ByteDance Create Video Asset", - category="video/partner/ByteDance", + category="partner/video/ByteDance", description=( "Create a Seedance 2.0 personal video asset. Uploads the input video and " "registers it in the given asset group. If group_id is empty, runs a real-person " diff --git a/comfy_api_nodes/nodes_bytedance_llm.py b/comfy_api_nodes/nodes_bytedance_llm.py index 007cac45f..cb41defa0 100644 --- a/comfy_api_nodes/nodes_bytedance_llm.py +++ b/comfy_api_nodes/nodes_bytedance_llm.py @@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedNode", display_name="ByteDance Seed", - category="text/partner/ByteDance", + category="partner/text/ByteDance", essentials_category="Text Generation", description="Generate text responses with ByteDance's Seed 2.0 models. " "Provide a text prompt and optionally one or more images or videos for multimodal context.", diff --git a/comfy_api_nodes/nodes_elevenlabs.py b/comfy_api_nodes/nodes_elevenlabs.py index 37eeb2601..eba578a45 100644 --- a/comfy_api_nodes/nodes_elevenlabs.py +++ b/comfy_api_nodes/nodes_elevenlabs.py @@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsSpeechToText", display_name="ElevenLabs Speech to Text", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Transcribe audio to text. " "Supports automatic language detection, speaker diarization, and audio event tagging.", inputs=[ @@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsVoiceSelector", display_name="ElevenLabs Voice Selector", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Select a predefined ElevenLabs voice for text-to-speech generation.", inputs=[ IO.Combo.Input( @@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToSpeech", display_name="ElevenLabs Text to Speech", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Convert text to speech.", inputs=[ IO.Custom(ELEVENLABS_VOICE).Input( @@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsAudioIsolation", display_name="ElevenLabs Voice Isolation", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Remove background noise from audio, isolating vocals or speech.", inputs=[ IO.Audio.Input( @@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToSoundEffects", display_name="ElevenLabs Text to Sound Effects", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Generate sound effects from text descriptions.", inputs=[ IO.String.Input( @@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsInstantVoiceClone", display_name="ElevenLabs Instant Voice Clone", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Create a cloned voice from audio samples. " "Provide 1-8 audio recordings of the voice to clone.", inputs=[ @@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsSpeechToSpeech", display_name="ElevenLabs Speech to Speech", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Transform speech from one voice to another while preserving the original content and emotion.", inputs=[ IO.Custom(ELEVENLABS_VOICE).Input( @@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToDialogue", display_name="ElevenLabs Text to Dialogue", - category="audio/partner/ElevenLabs", + category="partner/audio/ElevenLabs", description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.", inputs=[ IO.Float.Input( diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 3cfd541b2..e75ef3835 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -300,7 +300,7 @@ class GeminiNode(IO.ComfyNode): return IO.Schema( node_id="GeminiNode", display_name="Google Gemini", - category="text/partner/Gemini", + category="partner/text/Gemini", description="Generate text responses with Google's Gemini AI model. " "You can provide multiple types of inputs (text, images, audio, video) " "as context for generating more relevant and meaningful responses.", @@ -541,7 +541,7 @@ class GeminiInputFiles(IO.ComfyNode): return IO.Schema( node_id="GeminiInputFiles", display_name="Gemini Input Files", - category="text/partner/Gemini", + category="partner/text/Gemini", description="Loads and prepares input files to include as inputs for Gemini LLM nodes. " "The files will be read by the Gemini model when generating a response. " "The contents of the text file count toward the token limit. " @@ -598,7 +598,7 @@ class GeminiImage(IO.ComfyNode): return IO.Schema( node_id="GeminiImageNode", display_name="Nano Banana (Google Gemini Image)", - category="image/partner/Gemini", + category="partner/image/Gemini", description="Edit images synchronously via Google API.", inputs=[ IO.String.Input( @@ -731,7 +731,7 @@ class GeminiImage2(IO.ComfyNode): return IO.Schema( node_id="GeminiImage2Node", display_name="Nano Banana Pro (Google Gemini Image)", - category="image/partner/Gemini", + category="partner/image/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( @@ -869,7 +869,7 @@ class GeminiNanoBanana2(IO.ComfyNode): return IO.Schema( node_id="GeminiNanoBanana2", display_name="Nano Banana 2", - category="image/partner/Gemini", + category="partner/image/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( @@ -1085,7 +1085,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode): return IO.Schema( node_id="GeminiNanoBanana2V2", display_name="Nano Banana 2", - category="image/partner/Gemini", + category="partner/image/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index ca8f534ed..2ae529813 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -54,7 +54,7 @@ class GrokImageNode(IO.ComfyNode): return IO.Schema( node_id="GrokImageNode", display_name="Grok Image", - category="image/partner/Grok", + category="partner/image/Grok", description="Generate images using Grok based on a text prompt", inputs=[ IO.Combo.Input( @@ -228,7 +228,7 @@ class GrokImageEditNode(IO.ComfyNode): return IO.Schema( node_id="GrokImageEditNode", display_name="Grok Image Edit", - category="image/partner/Grok", + category="partner/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ IO.Combo.Input( @@ -369,7 +369,7 @@ class GrokImageEditNodeV2(IO.ComfyNode): return IO.Schema( node_id="GrokImageEditNodeV2", display_name="Grok Image Edit", - category="image/partner/Grok", + category="partner/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ IO.String.Input( @@ -506,7 +506,7 @@ class GrokVideoNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoNode", display_name="Grok Video", - category="video/partner/Grok", + category="partner/video/Grok", description="Generate video from a prompt or an image", inputs=[ IO.Combo.Input( @@ -630,7 +630,7 @@ class GrokVideoEditNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoEditNode", display_name="Grok Video Edit", - category="video/partner/Grok", + category="partner/video/Grok", description="Edit an existing video based on a text prompt.", inputs=[ IO.Combo.Input("model", options=["grok-imagine-video"]), @@ -708,7 +708,7 @@ class GrokVideoReferenceNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoReferenceNode", display_name="Grok Reference-to-Video", - category="video/partner/Grok", + category="partner/video/Grok", description="Generate video guided by reference images as style and content references.", inputs=[ IO.String.Input( @@ -841,7 +841,7 @@ class GrokVideoExtendNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoExtendNode", display_name="Grok Video Extend", - category="video/partner/Grok", + category="partner/video/Grok", description="Extend an existing video with a seamless continuation based on a text prompt.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py index 22e679c29..062d3cf1d 100644 --- a/comfy_api_nodes/nodes_hitpaw.py +++ b/comfy_api_nodes/nodes_hitpaw.py @@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode): return IO.Schema( node_id="HitPawGeneralImageEnhance", display_name="HitPaw General Image Enhance", - category="image/partner/HitPaw", + category="partner/image/HitPaw", description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. " f"Maximum output: {MAX_MP_GENERATIVE} megapixels.", inputs=[ @@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode): return IO.Schema( node_id="HitPawVideoEnhance", display_name="HitPaw Video Enhance", - category="video/partner/HitPaw", + category="partner/video/HitPaw", description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. " "Prices shown are per second of video.", inputs=[ diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index 826a3bd2d..fcd27b7fb 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="TencentTextToModelNode", display_name="Hunyuan3D: Text to Model", - category="3d/partner/Tencent", + category="partner/3d/Tencent", essentials_category="3D", inputs=[ IO.Combo.Input( @@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="TencentImageToModelNode", display_name="Hunyuan3D: Image(s) to Model", - category="3d/partner/Tencent", + category="partner/3d/Tencent", essentials_category="3D", inputs=[ IO.Combo.Input( @@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode): return IO.Schema( node_id="TencentModelTo3DUVNode", display_name="Hunyuan3D: Model to UV", - category="3d/partner/Tencent", + category="partner/3d/Tencent", description="Perform UV unfolding on a 3D model to generate UV texture. " "Input model must have less than 30000 faces.", inputs=[ @@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode): return IO.Schema( node_id="Tencent3DTextureEditNode", display_name="Hunyuan3D: 3D Texture Edit", - category="3d/partner/Tencent", + category="partner/3d/Tencent", description="After inputting the 3D model, perform 3D model texture redrawing.", inputs=[ IO.MultiType.Input( @@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode): return IO.Schema( node_id="Tencent3DPartNode", display_name="Hunyuan3D: 3D Part", - category="3d/partner/Tencent", + category="partner/3d/Tencent", description="Automatically perform component identification and generation based on the model structure.", inputs=[ IO.MultiType.Input( @@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode): return IO.Schema( node_id="TencentSmartTopologyNode", display_name="Hunyuan3D: Smart Topology", - category="3d/partner/Tencent", + category="partner/3d/Tencent", description="Perform smart retopology on a 3D model. " "Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.", inputs=[ diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index edd9b9435..8018c3902 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -234,7 +234,7 @@ class IdeogramV1(IO.ComfyNode): return IO.Schema( node_id="IdeogramV1", display_name="Ideogram V1", - category="image/partner/Ideogram", + category="partner/image/Ideogram", description="Generates images using the Ideogram V1 model.", inputs=[ IO.String.Input( @@ -360,7 +360,7 @@ class IdeogramV2(IO.ComfyNode): return IO.Schema( node_id="IdeogramV2", display_name="Ideogram V2", - category="image/partner/Ideogram", + category="partner/image/Ideogram", description="Generates images using the Ideogram V2 model.", inputs=[ IO.String.Input( @@ -526,7 +526,7 @@ class IdeogramV3(IO.ComfyNode): return IO.Schema( node_id="IdeogramV3", display_name="Ideogram V3", - category="image/partner/Ideogram", + category="partner/image/Ideogram", description="Generates images using the Ideogram V3 model. " "Supports both regular image generation from text prompts and image editing with mask.", inputs=[ diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9925ec548..d11e42540 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControls", display_name="Kling Camera Controls", - category="video/partner/Kling", + category="partner/video/Kling", description="Allows specifying configuration options for Kling Camera Controls and motion control effects.", inputs=[ IO.Combo.Input("camera_control_type", options=KlingCameraControlType), @@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingTextToVideoNode", display_name="Kling Text to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Kling Text to Video Node", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProTextToVideoNode", display_name="Kling 3.0 Omni Text to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Use text prompts to generate videos with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProFirstLastFrameNode", display_name="Kling 3.0 Omni First-Last-Frame to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProImageToVideoNode", display_name="Kling 3.0 Omni Image to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Use up to 7 reference images to generate a video with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProVideoToVideoNode", display_name="Kling 3.0 Omni Video to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProEditVideoNode", display_name="Kling 3.0 Omni Edit Video", - category="video/partner/Kling", + category="partner/video/Kling", essentials_category="Video Generation", description="Edit an existing video with the latest model from Kling.", inputs=[ @@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProImageNode", display_name="Kling 3.0 Omni Image", - category="image/partner/Kling", + category="partner/image/Kling", description="Create or edit images with the latest model from Kling.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]), @@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControlT2VNode", display_name="Kling Text to Video (Camera Control)", - category="video/partner/Kling", + category="partner/video/Kling", description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingImage2VideoNode", display_name="Kling Image(First Frame) to Video", - category="video/partner/Kling", + category="partner/video/Kling", inputs=[ IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControlI2VNode", display_name="Kling Image to Video (Camera Control)", - category="video/partner/Kling", + category="partner/video/Kling", description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.", inputs=[ IO.Image.Input( @@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingStartEndFrameNode", display_name="Kling Start-End Frame to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.", inputs=[ IO.Image.Input( @@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode): return IO.Schema( node_id="KlingVideoExtendNode", display_name="Kling Video Extend", - category="video/partner/Kling", + category="partner/video/Kling", description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.", inputs=[ IO.String.Input( @@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode): return IO.Schema( node_id="KlingDualCharacterVideoEffectNode", display_name="Kling Dual Character Video Effects", - category="video/partner/Kling", + category="partner/video/Kling", description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.", inputs=[ IO.Image.Input("image_left", tooltip="Left side image"), @@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): return IO.Schema( node_id="KlingSingleImageVideoEffectNode", display_name="Kling Video Effects", - category="video/partner/Kling", + category="partner/video/Kling", description="Achieve different special effects when generating a video based on the effect_scene.", inputs=[ IO.Image.Input( @@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingLipSyncAudioToVideoNode", display_name="Kling Lip Sync Video with Audio", - category="video/partner/Kling", + category="partner/video/Kling", essentials_category="Video Generation", description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", inputs=[ @@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingLipSyncTextToVideoNode", display_name="Kling Lip Sync Video with Text", - category="video/partner/Kling", + category="partner/video/Kling", description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", inputs=[ IO.Video.Input("video"), @@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode): return IO.Schema( node_id="KlingVirtualTryOnNode", display_name="Kling Virtual Try On", - category="image/partner/Kling", + category="partner/image/Kling", description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.", inputs=[ IO.Image.Input("human_image"), @@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode): return IO.Schema( node_id="KlingImageGenerationNode", display_name="Kling 3.0 Image", - category="image/partner/Kling", + category="partner/image/Kling", description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode): return IO.Schema( node_id="KlingTextToVideoWithAudio", display_name="Kling 2.6 Text to Video with Audio", - category="video/partner/Kling", + category="partner/video/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."), @@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode): return IO.Schema( node_id="KlingImageToVideoWithAudio", display_name="Kling 2.6 Image(First Frame) to Video with Audio", - category="video/partner/Kling", + category="partner/video/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), IO.Image.Input("start_frame"), @@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode): return IO.Schema( node_id="KlingMotionControl", display_name="Kling Motion Control", - category="video/partner/Kling", + category="partner/video/Kling", inputs=[ IO.String.Input("prompt", multiline=True), IO.Image.Input("reference_image"), @@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingVideoNode", display_name="Kling 3.0 Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Generate videos with Kling V3. " "Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.", inputs=[ @@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingFirstLastFrameNode", display_name="Kling 3.0 First-Last-Frame to Video", - category="video/partner/Kling", + category="partner/video/Kling", description="Generate videos with Kling V3 using first and last frames.", inputs=[ IO.String.Input("prompt", multiline=True, default=""), @@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode): return IO.Schema( node_id="KlingAvatarNode", display_name="Kling Avatar 2.0", - category="video/partner/Kling", + category="partner/video/Kling", description="Generate broadcast-style digital human videos from a single photo and an audio file.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_krea.py b/comfy_api_nodes/nodes_krea.py index be04a272b..34369f05f 100644 --- a/comfy_api_nodes/nodes_krea.py +++ b/comfy_api_nodes/nodes_krea.py @@ -106,7 +106,7 @@ class Krea2ImageNode(IO.ComfyNode): return IO.Schema( node_id="Krea2ImageNode", display_name="Krea 2 Image", - category="image/partner/Krea", + category="partner/image/Krea", description=( "Generate images via Krea 2 — pick Medium (expressive illustrations) or " "Large (expressive photorealism). Supports an optional moodboard and up " @@ -229,7 +229,7 @@ class Krea2StyleReferenceNode(IO.ComfyNode): return IO.Schema( node_id="Krea2StyleReferenceNode", display_name="Krea 2 Style Reference", - category="image/partner/Krea", + category="partner/image/Krea", description=( "Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 " "Style Reference nodes (max 10) and feed the final `style_reference` output " diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 01791d354..878e04b4e 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="LtxvApiTextToVideo", display_name="LTXV Text To Video", - category="video/partner/LTXV", + category="partner/video/LTXV", description="Professional-quality videos with customizable duration and resolution.", inputs=[ IO.Combo.Input("model", options=list(MODELS_MAP.keys())), @@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="LtxvApiImageToVideo", display_name="LTXV Image To Video", - category="video/partner/LTXV", + category="partner/video/LTXV", description="Professional-quality videos with customizable duration and resolution based on start image.", inputs=[ IO.Image.Input("image", tooltip="First frame to be used for the video."), diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 08ae9904c..0d31ac77e 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode): return IO.Schema( node_id="LumaReferenceNode", display_name="Luma Reference", - category="image/partner/Luma", + category="partner/image/Luma", description="Holds an image and weight for use with Luma Generate Image node.", inputs=[ IO.Image.Input( @@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode): return IO.Schema( node_id="LumaConceptsNode", display_name="Luma Concepts", - category="video/partner/Luma", + category="partner/video/Luma", description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.", inputs=[ IO.Combo.Input( @@ -134,7 +134,7 @@ class LumaImageGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageNode", display_name="Luma Text to Image", - category="image/partner/Luma", + category="partner/image/Luma", description="Generates images synchronously based on prompt and aspect ratio.", inputs=[ IO.String.Input( @@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageModifyNode", display_name="Luma Image to Image", - category="image/partner/Luma", + category="partner/image/Luma", description="Modifies images synchronously based on prompt and aspect ratio.", inputs=[ IO.Image.Input( @@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaVideoNode", display_name="Luma Text to Video", - category="video/partner/Luma", + category="partner/video/Luma", description="Generates videos synchronously based on prompt and output_size.", inputs=[ IO.String.Input( @@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageToVideoNode", display_name="Luma Image to Video", - category="video/partner/Luma", + category="partner/video/Luma", description="Generates videos synchronously based on prompt, input images, and output_size.", inputs=[ IO.String.Input( @@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageNode2", display_name="Luma UNI-1 Image", - category="image/partner/Luma", + category="partner/image/Luma", description="Generate images from text using the Luma UNI-1 model.", inputs=[ IO.String.Input( @@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageEditNode2", display_name="Luma UNI-1 Image Edit", - category="image/partner/Luma", + category="partner/image/Luma", description="Edit an existing image with a text prompt using the Luma UNI-1 model.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py index a6aeb194a..4ce4735df 100644 --- a/comfy_api_nodes/nodes_magnific.py +++ b/comfy_api_nodes/nodes_magnific.py @@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageUpscalerCreativeNode", display_name="Magnific Image Upscale (Creative)", - category="image/partner/Magnific", + category="partner/image/Magnific", description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. " "Maximum output: 25.3 megapixels.", inputs=[ @@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): return IO.Schema( node_id="MagnificImageUpscalerPreciseV2Node", display_name="Magnific Image Upscale (Precise V2)", - category="image/partner/Magnific", + category="partner/image/Magnific", description="High-fidelity upscaling with fine control over sharpness, grain, and detail. " "Maximum output: 10060×10060 pixels.", inputs=[ @@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageStyleTransferNode", display_name="Magnific Image Style Transfer", - category="image/partner/Magnific", + category="partner/image/Magnific", description="Transfer the style from a reference image to your input image.", inputs=[ IO.Image.Input("image", tooltip="The image to apply style transfer to."), @@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageRelightNode", display_name="Magnific Image Relight", - category="image/partner/Magnific", + category="partner/image/Magnific", description="Relight an image with lighting adjustments and optional reference-based light transfer.", inputs=[ IO.Image.Input("image", tooltip="The image to relight."), @@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageSkinEnhancerNode", display_name="Magnific Image Skin Enhancer", - category="image/partner/Magnific", + category="partner/image/Magnific", description="Skin enhancement for portraits with multiple processing modes.", inputs=[ IO.Image.Input("image", tooltip="The portrait image to enhance."), diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py index 4fb670404..3a24f1095 100644 --- a/comfy_api_nodes/nodes_meshy.py +++ b/comfy_api_nodes/nodes_meshy.py @@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyTextToModelNode", display_name="Meshy: Text to Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.String.Input("prompt", multiline=True, default=""), @@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode): return IO.Schema( node_id="MeshyRefineNode", display_name="Meshy: Refine Draft Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", description="Refine a previously created draft model.", inputs=[ IO.Combo.Input("model", options=["latest"]), @@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyImageToModelNode", display_name="Meshy: Image to Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Image.Input("image"), @@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyMultiImageToModelNode", display_name="Meshy: Multi-Image to Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Autogrow.Input( @@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyRigModelNode", display_name="Meshy: Rig Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", description="Provides a rigged character in standard formats. " "Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, " "or humanoid assets with unclear limb and body structure.", @@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyAnimateModelNode", display_name="Meshy: Animate Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", description="Apply a specific animation action to a previously rigged character.", inputs=[ IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"), @@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode): return IO.Schema( node_id="MeshyTextureNode", display_name="Meshy: Texture Model", - category="3d/partner/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 338584148..6250af146 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -101,7 +101,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxTextToVideoNode", display_name="MiniMax Text to Video", - category="video/partner/MiniMax", + category="partner/video/MiniMax", description="Generates videos synchronously based on a prompt, and optional parameters.", inputs=[ IO.String.Input( @@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxImageToVideoNode", display_name="MiniMax Image to Video", - category="video/partner/MiniMax", + category="partner/video/MiniMax", description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( @@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxSubjectToVideoNode", display_name="MiniMax Subject to Video", - category="video/partner/MiniMax", + category="partner/video/MiniMax", description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( @@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxHailuoVideoNode", display_name="MiniMax Hailuo Video", - category="video/partner/MiniMax", + category="partner/video/MiniMax", description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 48c739dfe..0fe5fb9d0 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode): return IO.Schema( node_id="OpenAIDalle2", display_name="OpenAI DALL·E 2", - category="image/partner/OpenAI", + category="partner/image/OpenAI", description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.", inputs=[ IO.String.Input( @@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode): return IO.Schema( node_id="OpenAIDalle3", display_name="OpenAI DALL·E 3", - category="image/partner/OpenAI", + category="partner/image/OpenAI", description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.", inputs=[ IO.String.Input( @@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode): return IO.Schema( node_id="OpenAIGPTImage1", display_name="OpenAI GPT Image 2", - category="image/partner/OpenAI", + category="partner/image/OpenAI", description="Generates images synchronously via OpenAI's GPT Image endpoint.", is_deprecated=True, inputs=[ @@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode): return IO.Schema( node_id="OpenAIGPTImageNodeV2", display_name="OpenAI GPT Image 2", - category="image/partner/OpenAI", + category="partner/image/OpenAI", description="Generates images via OpenAI's GPT Image endpoint.", inputs=[ IO.String.Input( @@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode): return IO.Schema( node_id="OpenAIChatNode", display_name="OpenAI ChatGPT", - category="text/partner/OpenAI", + category="partner/text/OpenAI", essentials_category="Text Generation", description="Generate text responses from an OpenAI model.", inputs=[ @@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode): return IO.Schema( node_id="OpenAIInputFiles", display_name="OpenAI ChatGPT Input Files", - category="text/partner/OpenAI", + category="partner/text/OpenAI", description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.", inputs=[ IO.Combo.Input( @@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode): return IO.Schema( node_id="OpenAIChatConfig", display_name="OpenAI ChatGPT Advanced Options", - category="text/partner/OpenAI", + category="partner/text/OpenAI", description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.", inputs=[ IO.Combo.Input( diff --git a/comfy_api_nodes/nodes_openrouter.py b/comfy_api_nodes/nodes_openrouter.py index d2ebbef0d..ba98133f0 100644 --- a/comfy_api_nodes/nodes_openrouter.py +++ b/comfy_api_nodes/nodes_openrouter.py @@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode): return IO.Schema( node_id="OpenRouterLLMNode", display_name="OpenRouter LLM", - category="text/partner/OpenRouter", + category="partner/text/OpenRouter", essentials_category="Text Generation", description=( "Generate text responses through OpenRouter. Routes to a curated set of popular " diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 3861cfedd..4c8b723b9 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTemplateNode", display_name="PixVerse Template", - category="video/partner/PixVerse", + category="partner/video/PixVerse", inputs=[ IO.Combo.Input("template", options=list(pixverse_templates.keys())), ], @@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTextToVideoNode", display_name="PixVerse Text to Video", - category="video/partner/PixVerse", + category="partner/video/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.String.Input( @@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseImageToVideoNode", display_name="PixVerse Image to Video", - category="video/partner/PixVerse", + category="partner/video/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("image"), @@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTransitionVideoNode", display_name="PixVerse Transition Video", - category="video/partner/PixVerse", + category="partner/video/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("first_frame"), diff --git a/comfy_api_nodes/nodes_quiver.py b/comfy_api_nodes/nodes_quiver.py index ad045a7ef..34929fa0c 100644 --- a/comfy_api_nodes/nodes_quiver.py +++ b/comfy_api_nodes/nodes_quiver.py @@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode): return IO.Schema( node_id="QuiverTextToSVGNode", display_name="Quiver Text to SVG", - category="image/partner/Quiver", + category="partner/image/Quiver", description="Generate an SVG from a text prompt using Quiver AI.", inputs=[ IO.String.Input( @@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode): return IO.Schema( node_id="QuiverImageToSVGNode", display_name="Quiver Image to SVG", - category="image/partner/Quiver", + category="partner/image/Quiver", description="Vectorize a raster image into SVG using Quiver AI.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 07387821d..c44942f50 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode): return IO.Schema( node_id="RecraftColorRGB", display_name="Recraft Color RGB", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Create Recraft Color by choosing specific RGB values.", inputs=[ IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."), @@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode): return IO.Schema( node_id="RecraftControls", display_name="Recraft Controls", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Create Recraft Controls for customizing Recraft generation.", inputs=[ IO.Custom(RecraftIO.COLOR).Input("colors", optional=True), @@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftStyleV3RealisticImage", display_name="Recraft Style - Realistic Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3DigitalIllustration", display_name="Recraft Style - Digital Illustration", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3VectorIllustrationNode", display_name="Recraft Style - Realistic Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3LogoRaster", display_name="Recraft Style - Logo Raster", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)), @@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): return IO.Schema( node_id="RecraftStyleV3InfiniteStyleLibrary", display_name="Recraft Style - Infinite Style Library", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.", inputs=[ IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), @@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode): return IO.Schema( node_id="RecraftCreateStyleNode", display_name="Recraft Create Style", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Create a custom style from reference images. " "Upload 1-5 images to use as style references. " "Total size of all images is limited to 5 MB.", @@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftTextToImageNode", display_name="Recraft Text to Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Generates images synchronously based on prompt and resolution.", inputs=[ IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), @@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftImageToImageNode", display_name="Recraft Image to Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Modify image based on prompt and strength.", inputs=[ IO.Image.Input("image"), @@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode): return IO.Schema( node_id="RecraftImageInpaintingNode", display_name="Recraft Image Inpainting", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Modify image based on prompt and mask.", inputs=[ IO.Image.Input("image"), @@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode): return IO.Schema( node_id="RecraftTextToVectorNode", display_name="Recraft Text to Vector", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Generates SVG synchronously based on prompt and resolution.", inputs=[ IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True), @@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftVectorizeImageNode", display_name="Recraft Vectorize Image", - category="image/partner/Recraft", + category="partner/image/Recraft", essentials_category="Image Tools", description="Generates SVG synchronously from an input image.", inputs=[ @@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode): return IO.Schema( node_id="RecraftReplaceBackgroundNode", display_name="Recraft Replace Background", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Replace background on image, based on provided prompt.", inputs=[ IO.Image.Input("image"), @@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode): return IO.Schema( node_id="RecraftRemoveBackgroundNode", display_name="Recraft Remove Background", - category="image/partner/Recraft", + category="partner/image/Recraft", essentials_category="Image Tools", description="Remove background from image, and return processed image and mask.", inputs=[ @@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode): return IO.Schema( node_id="RecraftCrispUpscaleNode", display_name="Recraft Crisp Upscale Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Upscale image synchronously.\n" "Enhances a given raster image using ‘crisp upscale’ tool, " "increasing image resolution, making the image sharper and cleaner.", @@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): return IO.Schema( node_id="RecraftCreativeUpscaleNode", display_name="Recraft Creative Upscale Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Upscale image synchronously.\n" "Enhances a given raster image using ‘creative upscale’ tool, " "boosting resolution with a focus on refining small details and faces.", @@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftV4TextToImageNode", display_name="Recraft V4 Text to Image", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Generates images using Recraft V4 or V4 Pro models.", inputs=[ IO.String.Input( @@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode): return IO.Schema( node_id="RecraftV4TextToVectorNode", display_name="Recraft V4 Text to Vector", - category="image/partner/Recraft", + category="partner/image/Recraft", description="Generates SVG using Recraft V4 or V4 Pro models.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py index 2b15eadd7..177349a8b 100644 --- a/comfy_api_nodes/nodes_reve.py +++ b/comfy_api_nodes/nodes_reve.py @@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageCreateNode", display_name="Reve Image Create", - category="image/partner/Reve", + category="partner/image/Reve", description="Generate images from text descriptions using Reve.", inputs=[ IO.String.Input( @@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageEditNode", display_name="Reve Image Edit", - category="image/partner/Reve", + category="partner/image/Reve", description="Edit images using natural language instructions with Reve.", inputs=[ IO.Image.Input("image", tooltip="The image to edit."), @@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageRemixNode", display_name="Reve Image Remix", - category="image/partner/Reve", + category="partner/image/Reve", description="Combine reference images with text prompts to create new images using Reve.", inputs=[ IO.Autogrow.Input( diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index e14955661..0375a2123 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Regular", display_name="Rodin 3D Generate - Regular Generate", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Detail", display_name="Rodin 3D Generate - Detail Generate", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Smooth", display_name="Rodin 3D Generate - Smooth Generate", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Sketch", display_name="Rodin 3D Generate - Sketch Generate", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen2", display_name="Rodin 3D Generate - Gen-2 Generate", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen25_Image", display_name="Rodin 3D Gen-2.5 - Image to 3D", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=( "Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. " "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." @@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen25_Text", display_name="Rodin 3D Gen-2.5 - Text to 3D", - category="3d/partner/Rodin", + category="partner/3d/Rodin", description=( "Generate a 3D model from a text prompt via Rodin Gen-2.5. " "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 7357c733e..b9c5c81a1 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): return IO.Schema( node_id="RunwayImageToVideoNodeGen3a", display_name="Runway Image to Video (Gen3a Turbo)", - category="video/partner/Runway", + category="partner/video/Runway", description="Generate a video from a single starting frame using Gen3a Turbo model. " "Before diving in, review these best practices to ensure that " "your input selections will set your generation up for success: " @@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): return IO.Schema( node_id="RunwayImageToVideoNodeGen4", display_name="Runway Image to Video (Gen4 Turbo)", - category="video/partner/Runway", + category="partner/video/Runway", description="Generate a video from a single starting frame using Gen4 Turbo model. " "Before diving in, review these best practices to ensure that " "your input selections will set your generation up for success: " @@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="RunwayFirstLastFrameNode", display_name="Runway First-Last-Frame to Video", - category="video/partner/Runway", + category="partner/video/Runway", description="Upload first and last keyframes, draft a prompt, and generate a video. " "More complex transitions, such as cases where the Last frame is completely different " "from the First frame, may benefit from the longer 10s duration. " @@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RunwayTextToImageNode", display_name="Runway Text to Image", - category="image/partner/Runway", + category="partner/image/Runway", description="Generate an image from a text prompt using Runway's Gen 4 model. " "You can also include reference image to guide the generation.", inputs=[ diff --git a/comfy_api_nodes/nodes_sonilo.py b/comfy_api_nodes/nodes_sonilo.py index bc31a0074..9ce896ed0 100644 --- a/comfy_api_nodes/nodes_sonilo.py +++ b/comfy_api_nodes/nodes_sonilo.py @@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode): return IO.Schema( node_id="SoniloVideoToMusic", display_name="Sonilo Video to Music", - category="audio/partner/Sonilo", + category="partner/audio/Sonilo", description="Generate music from video content using Sonilo's AI model. " "Analyzes the video and creates matching music.", inputs=[ @@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode): return IO.Schema( node_id="SoniloTextToMusic", display_name="Sonilo Text to Music", - category="audio/partner/Sonilo", + category="partner/audio/Sonilo", description="Generate music from a text prompt using Sonilo's AI model. " "Leave duration at 0 to let the model infer it from the prompt.", inputs=[ diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index 83cfca495..4ff1d649f 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode): return IO.Schema( node_id="OpenAIVideoSora2", display_name="OpenAI Sora - Video (DEPRECATED)", - category="video/partner/Sora", + category="partner/video/Sora", description=( "OpenAI video and audio generation.\n\n" "DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. " diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index a1753d647..9eaba173b 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode): return IO.Schema( node_id="StabilityStableImageUltraNode", display_name="Stability AI Stable Image Ultra", - category="image/partner/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.String.Input( @@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): return IO.Schema( node_id="StabilityStableImageSD_3_5Node", display_name="Stability AI Stable Diffusion 3.5 Image", - category="image/partner/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.String.Input( @@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleConservativeNode", display_name="Stability AI Upscale Conservative", - category="image/partner/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleCreativeNode", display_name="Stability AI Upscale Creative", - category="image/partner/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleFastNode", display_name="Stability AI Upscale Fast", - category="image/partner/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode): return IO.Schema( node_id="StabilityTextToAudio", display_name="Stability AI Text To Audio", - category="audio/partner/Stability AI", + category="partner/audio/Stability AI", essentials_category="Audio", description=cleandoc(cls.__doc__ or ""), inputs=[ @@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode): return IO.Schema( node_id="StabilityAudioToAudio", display_name="Stability AI Audio To Audio", - category="audio/partner/Stability AI", + category="partner/audio/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Combo.Input( @@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode): return IO.Schema( node_id="StabilityAudioInpaint", display_name="Stability AI Audio Inpaint", - category="audio/partner/Stability AI", + category="partner/audio/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Combo.Input( diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index d0906ee44..f7ef4cbf6 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode): return IO.Schema( node_id="TopazImageEnhance", display_name="Topaz Image Enhance", - category="image/partner/Topaz", + category="partner/image/Topaz", description="Industry-standard upscaling and image enhancement.", inputs=[ IO.Combo.Input("model", options=["Reimagine"]), @@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode): return IO.Schema( node_id="TopazVideoEnhance", display_name="Topaz Video Enhance (Legacy)", - category="video/partner/Topaz", + category="partner/video/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), @@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode): return IO.Schema( node_id="TopazVideoEnhanceV2", display_name="Topaz Video Enhance", - category="video/partner/Topaz", + category="partner/video/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 4820e26c1..a3f2cb053 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -83,7 +83,7 @@ class TripoTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoTextToModelNode", display_name="Tripo: Text to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.String.Input("prompt", multiline=True), IO.String.Input("negative_prompt", multiline=True, optional=True), @@ -210,7 +210,7 @@ class TripoImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoImageToModelNode", display_name="Tripo: Image to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Image.Input("image"), IO.Combo.Input( @@ -358,7 +358,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoMultiviewToModelNode", display_name="Tripo: Multiview to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Image.Input("image"), IO.Image.Input("image_left", optional=True), @@ -518,7 +518,7 @@ class TripoTextureNode(IO.ComfyNode): return IO.Schema( node_id="TripoTextureNode", display_name="Tripo: Texture model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Custom("MODEL_TASK_ID").Input("model_task_id"), IO.Boolean.Input("texture", default=True, optional=True), @@ -595,7 +595,7 @@ class TripoRefineNode(IO.ComfyNode): return IO.Schema( node_id="TripoRefineNode", display_name="Tripo: Refine Draft model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", description="Refine a draft model created by v1.4 Tripo models only.", inputs=[ IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), @@ -635,7 +635,7 @@ class TripoRigNode(IO.ComfyNode): return IO.Schema( node_id="TripoRigNode", display_name="Tripo: Rig model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], outputs=[ IO.String.Output(display_name="model_file"), # for backward compatibility only @@ -672,7 +672,7 @@ class TripoRetargetNode(IO.ComfyNode): return IO.Schema( node_id="TripoRetargetNode", display_name="Tripo: Retarget rigged model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Custom("RIG_TASK_ID").Input("original_model_task_id"), IO.Combo.Input( @@ -737,7 +737,7 @@ class TripoConversionNode(IO.ComfyNode): return IO.Schema( node_id="TripoConversionNode", display_name="Tripo: Convert model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), @@ -1051,7 +1051,7 @@ class TripoP1TextToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoP1TextToModelNode", display_name="Tripo P1: Text to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", description="Tripo P1 text-to-3D. Optimized for low-poly, game-ready meshes with stable topology.", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Up to 1024 characters."), @@ -1122,7 +1122,7 @@ class TripoP1ImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoP1ImageToModelNode", display_name="Tripo P1: Image to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", description="Tripo P1 image-to-3D. Optimized for low-poly, game-ready meshes.", inputs=[ IO.Image.Input("image"), @@ -1202,7 +1202,7 @@ class TripoP1MultiviewToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoP1MultiviewToModelNode", display_name="Tripo P1: Multiview to Model", - category="3d/partner/Tripo", + category="partner/3d/Tripo", description="Tripo P1 multiview-to-3D from 2-4 reference images in [front, left, back, right] order. " "Front is required; any combination of the other three may be omitted.", inputs=[ diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 068862397..ed34e928b 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="VeoVideoGenerationNode", display_name="Google Veo 2 Video Generation", - category="video/partner/Veo", + category="partner/video/Veo", description="Generates videos from text prompts using Google's Veo 2 API", inputs=[ IO.String.Input( @@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="Veo3VideoGenerationNode", display_name="Google Veo 3 Video Generation", - category="video/partner/Veo", + category="partner/video/Veo", description="Generates videos from text prompts using Google's Veo 3 API", inputs=[ IO.String.Input( @@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="Veo3FirstLastFrameNode", display_name="Google Veo 3 First-Last-Frame to Video", - category="video/partner/Veo", + category="partner/video/Veo", description="Generate video using prompt and first and last frames.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 16f6113de..8c5a43f5b 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduTextToVideoNode", display_name="Vidu Text To Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate video from a text prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduImageToVideoNode", display_name="Vidu Image To Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate video from image and optional prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduReferenceVideoNode", display_name="Vidu Reference To Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate video from multiple images and a prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduStartEndToVideoNode", display_name="Vidu Start End To Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from start and end frames and a prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2TextToVideoNode", display_name="Vidu2 Text-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate video from a text prompt", inputs=[ IO.Combo.Input("model", options=["viduq2"]), @@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2ImageToVideoNode", display_name="Vidu2 Image-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from an image and an optional prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), @@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2ReferenceVideoNode", display_name="Vidu2 Reference-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from multiple reference images and a prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2"]), @@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2StartEndToVideoNode", display_name="Vidu2 Start/End Frame-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from a start frame, an end frame, and a prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), @@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduExtendVideoNode", display_name="Vidu Video Extension", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Extend an existing video by generating additional frames.", inputs=[ IO.DynamicCombo.Input( @@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduMultiFrameVideoNode", display_name="Vidu Multi-Frame Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video with multiple keyframe transitions.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]), @@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3TextToVideoNode", display_name="Vidu Q3 Text-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate video from a text prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3ImageToVideoNode", display_name="Vidu Q3 Image-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from an image and an optional prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3StartEndToVideoNode", display_name="Vidu Q3 Start/End Frame-to-Video Generation", - category="video/partner/Vidu", + category="partner/video/Vidu", description="Generate a video from a start frame, an end frame, and a prompt.", inputs=[ IO.DynamicCombo.Input( diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index a235dc387..b7b97d70f 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode): return IO.Schema( node_id="WanTextToImageApi", display_name="Wan Text to Image", - category="image/partner/Wan", + category="partner/image/Wan", description="Generates an image based on a text prompt.", inputs=[ IO.Combo.Input( @@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode): return IO.Schema( node_id="WanImageToImageApi", display_name="Wan Image to Image", - category="image/partner/Wan", + category="partner/image/Wan", description="Generates an image from one or two input images and a text prompt. " "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).", inputs=[ @@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanTextToVideoApi", display_name="Wan Text to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generates a video based on a text prompt.", inputs=[ IO.Combo.Input( @@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanImageToVideoApi", display_name="Wan Image to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generates a video from the first frame and a text prompt.", inputs=[ IO.Combo.Input( @@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanReferenceVideoApi", display_name="Wan Reference to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Use the character and voice from input videos, combined with a prompt, " "to generate a new video that maintains character consistency.", inputs=[ @@ -828,7 +828,7 @@ class Wan2TextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2TextToVideoApi", display_name="Wan 2.7 Text to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generates a video based on a text prompt using the Wan 2.7 model.", inputs=[ IO.DynamicCombo.Input( @@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2ImageToVideoApi", display_name="Wan 2.7 Image to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generate a video from a first-frame image, with optional last-frame image and audio.", inputs=[ IO.DynamicCombo.Input( @@ -1152,7 +1152,7 @@ class Wan2VideoContinuationApi(IO.ComfyNode): return IO.Schema( node_id="Wan2VideoContinuationApi", display_name="Wan 2.7 Video Continuation", - category="video/partner/Wan", + category="partner/video/Wan", description="Continue a video from where it left off, with optional last-frame control.", inputs=[ IO.DynamicCombo.Input( @@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode): return IO.Schema( node_id="Wan2VideoEditApi", display_name="Wan 2.7 Video Edit", - category="video/partner/Wan", + category="partner/video/Wan", description="Edit a video using text instructions, reference images, or style transfer.", inputs=[ IO.DynamicCombo.Input( @@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2ReferenceVideoApi", display_name="Wan 2.7 Reference to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generate a video featuring a person or object from reference materials. " "Supports single-character performances and multi-character interactions.", inputs=[ @@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseTextToVideoApi", display_name="HappyHorse Text to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generates a video based on a text prompt using the HappyHorse model.", inputs=[ IO.DynamicCombo.Input( @@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseImageToVideoApi", display_name="HappyHorse Image to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generate a video from a first-frame image using the HappyHorse model.", inputs=[ IO.DynamicCombo.Input( @@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseVideoEditApi", display_name="HappyHorse Video Edit", - category="video/partner/Wan", + category="partner/video/Wan", description="Edit a video using text instructions or reference images with the HappyHorse model. " "Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.", inputs=[ @@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseReferenceVideoApi", display_name="HappyHorse Reference to Video", - category="video/partner/Wan", + category="partner/video/Wan", description="Generate a video featuring a person or object from reference materials with the HappyHorse " "model. Supports single-character performances and multi-character interactions.", inputs=[ diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py index a250015c3..5839f9d37 100644 --- a/comfy_api_nodes/nodes_wavespeed.py +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode): return IO.Schema( node_id="WavespeedFlashVSRNode", display_name="FlashVSR Video Upscale", - category="video/partner/WaveSpeed", + category="partner/video/WaveSpeed", description="Fast, high-quality video upscaler that " "boosts resolution and restores clarity for low-resolution or blurry footage.", inputs=[ @@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode): return IO.Schema( node_id="WavespeedImageUpscaleNode", display_name="WaveSpeed Image Upscale", - category="image/partner/WaveSpeed", + category="partner/image/WaveSpeed", description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", inputs=[ IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), From c96fcddb8100d7d5358c33f5fb4fab33cb6a2da0 Mon Sep 17 00:00:00 2001 From: person4268 <28717044+person4268@users.noreply.github.com> Date: Tue, 2 Jun 2026 01:07:48 -0400 Subject: [PATCH 28/32] Radiance: support variant with nonzero txt_ids (#14206) --- comfy/ldm/chroma_radiance/model.py | 8 ++++++++ comfy/model_detection.py | 4 ++++ comfy_extras/nodes_chroma_radiance.py | 10 ++++++++++ 3 files changed, 22 insertions(+) diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 4fb56165e..86af98d36 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams): # None means use the same dtype as the model. nerf_embedder_dtype: Optional[torch.dtype] use_x0: bool + # Use sequential txt_ids instead of zeros + use_sequential_txt_ids: bool class ChromaRadiance(Chroma): """ @@ -162,6 +164,9 @@ class ChromaRadiance(Chroma): if params.use_x0: self.register_buffer("__x0__", torch.tensor([])) + if params.use_sequential_txt_ids: + self.register_buffer("__sequential__", torch.tensor([])) + @property def _nerf_final_layer(self) -> nn.Module: if self.params.nerf_final_head_type == "linear": @@ -313,6 +318,9 @@ class ChromaRadiance(Chroma): img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + # Radiance after 2026-05-22 uses sequential txt_ids instead of zeros + if params.use_sequential_txt_ids: + txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1) img_out = self.forward_orig( img, diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 73354b0d2..24e742a7f 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -313,6 +313,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["use_x0"] = True else: dit_config["use_x0"] = False + if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids + dit_config["use_sequential_txt_ids"] = True + else: + dit_config["use_sequential_txt_ids"] = False else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py index ca427e5cb..a4f673001 100644 --- a/comfy_extras/nodes_chroma_radiance.py +++ b/comfy_extras/nodes_chroma_radiance.py @@ -65,6 +65,12 @@ class ChromaRadianceOptions(io.ComfyNode): tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", advanced=True, ), + io.Boolean.Input( + id="force_sequential_txt_ids", + default=False, + tooltip="Force usage of sequential text token IDs instead of zeroes. Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.", + advanced=True, + ), ], outputs=[io.Model.Output()], ) @@ -78,11 +84,15 @@ class ChromaRadianceOptions(io.ComfyNode): start_sigma: float, end_sigma: float, nerf_tile_size: int, + force_sequential_txt_ids: bool, ) -> io.NodeOutput: radiance_options = {} if nerf_tile_size >= 0: radiance_options["nerf_tile_size"] = nerf_tile_size + if force_sequential_txt_ids: + radiance_options["use_sequential_txt_ids"] = True + if not radiance_options: return io.NodeOutput(model) From e9207aa7ccb6f06fce42ef0e42e0d7450bef3b3f Mon Sep 17 00:00:00 2001 From: Quasar of Mikus <159663231+quasar-of-mikus@users.noreply.github.com> Date: Tue, 2 Jun 2026 13:05:24 -0400 Subject: [PATCH 29/32] fix (MultiGPU): prevent freeze on manual abort when using MultiGPU CFG Split (#14235) * fix (MultiGPU): prevent freeze on manual abort when using MultiGPU CFG Split Problem: Upon manual abort application hangs indefinitely. `InterruptProcessingException` inherits from `BaseException` and bypasses MultiGPU's worker error handling block so thread dies silently, leaving the main thread waiting forever for `result_q.get()` Fix: Catch `comfy.model_management.InterruptProcessingException` instead of `Exception` so it's caught and passed back via `result_q` to unblock the main thread when manual abort signal fires. * oops --- comfy/multigpu.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index bb9d334d3..2b6d8260d 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -54,6 +54,8 @@ class MultiGPUThreadPool: try: result = fn(*args, **kwargs) result_q.put((result, None)) + except comfy.model_management.InterruptProcessingException as e: + result_q.put((None, e)) except Exception as e: result_q.put((None, e)) From dc10c0133ebda1d6438c65fb6be9cd5eb20b4434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Tue, 2 Jun 2026 22:40:49 +0300 Subject: [PATCH 30/32] PiD: Add SDXL and QwenImage (#14240) --- comfy_extras/nodes_pid.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_pid.py b/comfy_extras/nodes_pid.py index 811b9ae8e..71855254e 100644 --- a/comfy_extras/nodes_pid.py +++ b/comfy_extras/nodes_pid.py @@ -21,8 +21,8 @@ class PiDConditioning(io.ComfyNode): inputs=[ io.Conditioning.Input("positive"), io.Latent.Input("latent", tooltip="latent (from VAEEncode or a KSampler)."), - io.Combo.Input("latent_format", options=["flux", "sd3"], default="flux", - tooltip="Flux1 and Flux2 latents auto-detected from channel dim, sd3 has to be selected manually."), + io.Combo.Input("latent_format", options=["flux", "sd3", "sdxl", "qwenimage"], default="flux", + tooltip="Flux1 (16-ch) and Flux2 (128-ch) latents are auto-detected from channel dim under 'flux'. For SD3 (16-ch), SDXL (4-ch), or QwenImage (16-ch), select manually."), io.Float.Input( "degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="0 = clean latent. Increase to denoise corrupted latent outputs.", @@ -36,9 +36,17 @@ class PiDConditioning(io.ComfyNode): samples = latent["samples"] if latent_format == "flux": fmt_cls = comfy.latent_formats.Flux2 if samples.shape[1] == 128 else comfy.latent_formats.Flux - else: + elif latent_format == "sd3": fmt_cls = comfy.latent_formats.SD3 + elif latent_format == "sdxl": + fmt_cls = comfy.latent_formats.SDXL + elif latent_format == "qwenimage": + fmt_cls = comfy.latent_formats.Wan21 + else: + raise ValueError(f"Unknown latent_format: {latent_format}") lq_latent = fmt_cls().process_in(samples) + if lq_latent.ndim == 5: + lq_latent = lq_latent[:, :, 0] sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32) return io.NodeOutput(node_helpers.conditioning_set_values( positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t}, From d4c7ebff9c318698082fed85a0ebf3bdc1fe3d2a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:52:41 -0700 Subject: [PATCH 31/32] Remove old useless no comfy kitchen fallback. (#14245) * Remove old fallback used when no comfy kitchen. * Remove unused logging import --- comfy/ldm/flux/math.py | 32 +++++++++++++------------------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6d0aed827..891dea7dd 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,7 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -import logging +import comfy.quant_ops def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -44,21 +44,15 @@ def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) -try: - import comfy.quant_ops - q_apply_rope = comfy.quant_ops.ck.apply_rope - q_apply_rope1 = comfy.quant_ops.ck.apply_rope1 - def apply_rope(xq, xk, freqs_cis): - if comfy.model_management.in_training: - return _apply_rope(xq, xk, freqs_cis) - else: - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) - def apply_rope1(x, freqs_cis): - if comfy.model_management.in_training: - return _apply_rope1(x, freqs_cis) - else: - return q_apply_rope1(x, freqs_cis) -except: - logging.warning("No comfy kitchen, using old apply_rope functions.") - apply_rope = _apply_rope - apply_rope1 = _apply_rope1 +def apply_rope(xq, xk, freqs_cis): + if comfy.model_management.in_training: + return _apply_rope(xq, xk, freqs_cis) + else: + return comfy.quant_ops.ck.apply_rope(xq, xk, freqs_cis) + + +def apply_rope1(x, freqs_cis): + if comfy.model_management.in_training: + return _apply_rope1(x, freqs_cis) + else: + return comfy.quant_ops.ck.apply_rope1(x, freqs_cis) From bd7da053aeaa3424828f7a0fb6ebeffeec9f5876 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 3 Jun 2026 11:57:16 +1000 Subject: [PATCH 32/32] comfy-aimdo: 0.4.8 (#14244) Aimdo 0.4.8 fixes a crash in multi-gpu due to contention on the singleton bounce buffer. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index b09d31a8b..7dff9e3c3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=16.0.0 comfy-kitchen==0.2.10 -comfy-aimdo==0.4.7 +comfy-aimdo==0.4.8 requests simpleeval>=1.0.0 blake3