Merge branch 'master' into feat/cut-release-workflow
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled

This commit is contained in:
Jedrzej Kosinski 2026-05-18 09:43:02 -07:00 committed by GitHub
commit 91fe94bb78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
86 changed files with 26563 additions and 964 deletions

View File

@ -89,3 +89,12 @@ rules:
then:
field: description
function: truthy
overrides:
# /ws uses HTTP 101 (Switching Protocols) — a legitimate response for a
# WebSocket upgrade, but not a 2xx, so operation-success-response fires
# as a false positive. OpenAPI 3.x has no native WebSocket support.
- files:
- "openapi.yaml#/paths/~1ws"
rules:
operation-success-response: off

View File

@ -38,7 +38,7 @@
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
- ComfyUI natively supports the latest open-source state of the art models.
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
- It is available on Windows, Linux, and macOS, locally with our [desktop application](https://www.comfy.org/download), our [portable install](#installing) or on our [cloud](https://www.comfy.org/cloud).
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
- It integrates seamlessly into production pipelines with our API endpoints.
@ -429,6 +429,8 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
See also: [https://www.comfy.org/](https://www.comfy.org/)
> _psst — we're hiring!_ Help build ComfyUI: [comfy.org/careers](https://www.comfy.org/careers)
## Frontend Development
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.

44
SECURITY.md Normal file
View File

@ -0,0 +1,44 @@
# Security Policy
## Scope
ComfyUI is designed to run locally. By default, the server binds to `127.0.0.1`, meaning only the user's own machine can reach it. Our threat model assumes:
- The user installed ComfyUI through a supported channel: the desktop application, the portable build, or a manual install following the README.
- The user has not installed untrusted custom nodes. Custom nodes are arbitrary Python code and are trusted as much as any other software the user chooses to install.
- Anyone with access to the ComfyUI URL is trusted (a direct consequence of the localhost-only default).
- PyTorch and other dependencies are at the versions we ship or recommend in the README.
A report is in scope only if it affects a user operating within this threat model.
## What We Consider a Vulnerability
We want to hear about issues where a **reasonable user** — someone who does not install random untrusted nodes and who reads UI prompts and warnings before clicking through them — can be harmed by ComfyUI itself.
The clearest example: a workflow file that such a user might plausibly load and run, using only built-in nodes, that results in **untrusted code execution, arbitrary file read/write outside expected directories, or credential/data exfiltration**.
When submitting a report, please include a clear description of *why this is a problem for a typical local ComfyUI user*. Reports without this context are difficult to act on.
## What We Do Not Consider a Security Vulnerability
Please report the following through our regular [GitHub issues](https://github.com/comfyanonymous/ComfyUI/issues) instead. Filing them as security reports will likely cause them to be deprioritized or closed.
- **Issues requiring `--listen` or any non-default network exposure.** ComfyUI binds to localhost by default. If a remote attacker needs to reach the server for the attack to work, the user has chosen to expose it and is responsible for securing that deployment (firewall, reverse proxy, authentication, etc.). These are bugs, not vulnerabilities.
- **`torch.load` and related deserialization issues in old PyTorch versions.** These are upstream PyTorch issues. Our distributions ship with — and our documentation recommends — recent PyTorch versions where these are addressed.
- **Vulnerabilities that depend on outdated library versions** that we neither ship nor recommend (e.g., requiring PyTorch 2.6 or older).
- **Issues that require a specific custom node to be installed.** Custom nodes are third-party code. Report these to the maintainer of that node.
- **Crashes, hangs, or resource exhaustion from a loaded workflow.** Annoying, but not a security issue in our model. File a regular bug.
- **Social-engineering scenarios** where the user is expected to ignore an explicit UI warning or prompt.
## Reporting
If you believe you have found an issue that falls within the scope above, please report it privately via GitHub's [Report a vulnerability](https://github.com/comfyanonymous/ComfyUI/security/advisories/new) feature rather than opening a public issue.
Please include:
1. A description of the vulnerability and the affected component.
2. Reproduction steps, ideally with a minimal workflow file or proof-of-concept.
3. The ComfyUI version, install method (desktop / portable / manual), and OS.
4. An explanation of how this affects a typical local user as described in the threat model.
We will acknowledge valid reports and coordinate a fix and disclosure timeline with you.

View File

@ -38,40 +38,54 @@ def is_valid_version(version: str) -> bool:
pattern = r"^(\d+)\.(\d+)\.(\d+)$"
return bool(re.match(pattern, version))
def get_installed_frontend_version():
"""Get the currently installed frontend package version."""
frontend_version_str = version("comfyui-frontend-package")
return frontend_version_str
def get_required_frontend_version():
return get_required_packages_versions().get("comfyui-frontend-package", None)
def check_frontend_version():
"""Check if the frontend version is up to date."""
COMFY_PACKAGE_VERSIONS = []
def get_comfy_package_versions():
"""List installed/required versions for every comfy* package in requirements.txt."""
if COMFY_PACKAGE_VERSIONS:
return COMFY_PACKAGE_VERSIONS.copy()
out = COMFY_PACKAGE_VERSIONS
for name, required in (get_required_packages_versions() or {}).items():
if not name.startswith("comfy"):
continue
try:
installed = version(name)
except Exception:
installed = None
out.append({"name": name, "installed": installed, "required": required})
return out.copy()
try:
frontend_version_str = get_installed_frontend_version()
frontend_version = parse_version(frontend_version_str)
required_frontend_str = get_required_frontend_version()
required_frontend = parse_version(required_frontend_str)
if frontend_version < required_frontend:
def check_comfy_packages_versions():
"""Warn for every comfy* package whose installed version is below requirements.txt."""
from packaging.version import InvalidVersion, parse as parse_pep440
for pkg in get_comfy_package_versions():
installed_str = pkg["installed"]
required_str = pkg["required"]
if not installed_str or not required_str:
continue
try:
outdated = parse_pep440(installed_str) < parse_pep440(required_str)
except InvalidVersion as e:
logging.error(f"Failed to check {pkg['name']} version: {e}")
continue
if outdated:
app.logger.log_startup_warning(
f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}.
{frontend_install_warning_message()}
{get_missing_requirements_message()}
________________________________________________________________________
""".strip()
)
else:
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
except Exception as e:
logging.error(f"Failed to check frontend version: {e}")
logging.info("{} version: {}".format(pkg["name"], installed_str))
REQUEST_TIMEOUT = 10 # seconds
@ -201,6 +215,11 @@ class FrontendManager:
def get_required_templates_version(cls) -> str:
return get_required_packages_versions().get("comfyui-workflow-templates", None)
@classmethod
def get_comfy_package_versions(cls):
"""List installed/required versions for every comfy* package in requirements.txt."""
return get_comfy_package_versions()
@classmethod
def default_frontend_path(cls) -> str:
try:
@ -341,7 +360,7 @@ comfyui-workflow-templates is not installed.
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
check_frontend_version()
check_comfy_packages_versions()
return cls.default_frontend_path()
repo_owner, repo_name, version = cls.parse_version_string(version_string)
@ -403,7 +422,7 @@ comfyui-workflow-templates is not installed.
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
check_frontend_version()
check_comfy_packages_versions()
return cls.default_frontend_path()
@classmethod
def template_asset_handler(cls):

File diff suppressed because it is too large Load Diff

View File

@ -4234,7 +4234,7 @@
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Depth to video",
"description": "Generates video from depth maps using LTX-2, with optional synchronized audio."
"description": "Generates depth-controlled video with LTX-2: motion and structure follow a depth-reference video alongside text prompting, optional first-frame image conditioning, with optional synchronized audio."
},
{
"id": "38b60539-50a7-42f9-a5fe-bdeca26272e2",

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,858 @@
{
"revision": 0,
"last_node_id": 16,
"last_link_id": 0,
"nodes": [
{
"id": 16,
"type": "022693be-2baa-4009-870a-28921508a7ef",
"pos": [
-2990,
-3240
],
"size": [
410,
200
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "video",
"name": "video",
"type": "VIDEO",
"link": null
},
{
"label": "multiplier",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": null
},
{
"label": "enable_fps_multiplier",
"name": "value_1",
"type": "BOOLEAN",
"widget": {
"name": "value_1"
},
"link": null
},
{
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": null
}
],
"outputs": [
{
"label": "VIDEO",
"name": "VIDEO_1",
"type": "VIDEO",
"links": []
},
{
"name": "IMAGE",
"type": "IMAGE",
"links": null
}
],
"properties": {
"proxyWidgets": [
[
"9",
"value"
],
[
"13",
"value"
],
[
"1",
"model_name"
]
],
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.19.3"
},
"widgets_values": [],
"title": "Frame Interpolation"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "022693be-2baa-4009-870a-28921508a7ef",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 17,
"lastLinkId": 28,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Frame Interpolation",
"inputNode": {
"id": -10,
"bounding": [
-2810,
-3070,
159.7421875,
120
]
},
"outputNode": {
"id": -20,
"bounding": [
-1270,
-3075,
120,
80
]
},
"inputs": [
{
"id": "05e31c51-dcb6-4a1e-9651-1b9ad4f7a287",
"name": "video",
"type": "VIDEO",
"linkIds": [
2
],
"localized_name": "video",
"pos": [
-2670.2578125,
-3050
]
},
{
"id": "feecb409-7d1c-4a99-9c63-50c5fecdd3c9",
"name": "value",
"type": "INT",
"linkIds": [
22
],
"label": "multiplier",
"pos": [
-2670.2578125,
-3030
]
},
{
"id": "0b8a861b-b581-4068-9e8c-f8d15daf1ca6",
"name": "value_1",
"type": "BOOLEAN",
"linkIds": [
23
],
"label": "enable_fps_multiplier",
"pos": [
-2670.2578125,
-3010
]
},
{
"id": "a22b101e-8773-4e17-a297-7ee3aae09162",
"name": "model_name",
"type": "COMBO",
"linkIds": [
24
],
"pos": [
-2670.2578125,
-2990
]
}
],
"outputs": [
{
"id": "ef2ada05-d5aa-492a-9394-6c3e71e39ebb",
"name": "VIDEO_1",
"type": "VIDEO",
"linkIds": [
26
],
"label": "VIDEO",
"pos": [
-1250,
-3055
]
},
{
"id": "5aacc622-2a07-4983-b31c-e04461f7f953",
"name": "IMAGE",
"type": "IMAGE",
"linkIds": [
28
],
"pos": [
-1250,
-3035
]
}
],
"widgets": [],
"nodes": [
{
"id": 1,
"type": "FrameInterpolationModelLoader",
"pos": [
-2510,
-3370
],
"size": [
370,
90
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "model_name",
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": 24
}
],
"outputs": [
{
"localized_name": "INTERP_MODEL",
"name": "INTERP_MODEL",
"type": "INTERP_MODEL",
"links": [
1
]
}
],
"properties": {
"Node name for S&R": "FrameInterpolationModelLoader",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.19.3",
"models": [
{
"name": "film_net_fp16.safetensors",
"url": "https://huggingface.co/Comfy-Org/frame_interpolation/resolve/main/frame_interpolation/film_net_fp16.safetensors",
"directory": "frame_interpolation"
}
]
},
"widgets_values": [
"film_net_fp16.safetensors"
]
},
{
"id": 2,
"type": "FrameInterpolate",
"pos": [
-2040,
-3370
],
"size": [
270,
110
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "interp_model",
"name": "interp_model",
"type": "INTERP_MODEL",
"link": 1
},
{
"localized_name": "images",
"name": "images",
"type": "IMAGE",
"link": 3
},
{
"localized_name": "multiplier",
"name": "multiplier",
"type": "INT",
"widget": {
"name": "multiplier"
},
"link": 8
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": [
4,
28
]
}
],
"properties": {
"Node name for S&R": "FrameInterpolate",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.19.3"
},
"widgets_values": [
2
]
},
{
"id": 5,
"type": "CreateVideo",
"pos": [
-1600,
-3370
],
"size": [
270,
110
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "images",
"name": "images",
"type": "IMAGE",
"link": 4
},
{
"localized_name": "audio",
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": 5
},
{
"localized_name": "fps",
"name": "fps",
"type": "FLOAT",
"widget": {
"name": "fps"
},
"link": 12
}
],
"outputs": [
{
"localized_name": "VIDEO",
"name": "VIDEO",
"type": "VIDEO",
"links": [
26
]
}
],
"properties": {
"Node name for S&R": "CreateVideo",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.19.3"
},
"widgets_values": [
30
]
},
{
"id": 9,
"type": "PrimitiveInt",
"pos": [
-2500,
-2970
],
"size": [
270,
90
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 22
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
8,
19
]
}
],
"title": "Int (Multiplier)",
"properties": {
"Node name for S&R": "PrimitiveInt",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.19.3"
},
"widgets_values": [
2,
"fixed"
]
},
{
"id": 10,
"type": "ComfySwitchNode",
"pos": [
-1610,
-3120
],
"size": [
270,
130
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"localized_name": "on_false",
"name": "on_false",
"type": "*",
"link": 11
},
{
"localized_name": "on_true",
"name": "on_true",
"type": "*",
"link": 13
},
{
"localized_name": "switch",
"name": "switch",
"type": "BOOLEAN",
"widget": {
"name": "switch"
},
"link": 15
}
],
"outputs": [
{
"localized_name": "output",
"name": "output",
"type": "*",
"links": [
12
]
}
],
"properties": {
"Node name for S&R": "ComfySwitchNode",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.19.3"
},
"widgets_values": [
true
]
},
{
"id": 13,
"type": "PrimitiveBoolean",
"pos": [
-2500,
-2770
],
"size": [
310,
90
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "BOOLEAN",
"widget": {
"name": "value"
},
"link": 23
}
],
"outputs": [
{
"localized_name": "BOOLEAN",
"name": "BOOLEAN",
"type": "BOOLEAN",
"links": [
15
]
}
],
"title": "Boolean (Apply multiplier to FPS?)",
"properties": {
"Node name for S&R": "PrimitiveBoolean",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.19.3"
},
"widgets_values": [
true
]
},
{
"id": 3,
"type": "GetVideoComponents",
"pos": [
-2500,
-3170
],
"size": [
230,
100
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "video",
"name": "video",
"type": "VIDEO",
"link": 2
}
],
"outputs": [
{
"localized_name": "images",
"name": "images",
"type": "IMAGE",
"links": [
3
]
},
{
"localized_name": "audio",
"name": "audio",
"type": "AUDIO",
"links": [
5
]
},
{
"localized_name": "fps",
"name": "fps",
"type": "FLOAT",
"links": [
11,
18
]
}
],
"properties": {
"Node name for S&R": "GetVideoComponents",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.19.3"
}
},
{
"id": 11,
"type": "ComfyMathExpression",
"pos": [
-2090,
-3070
],
"size": [
400,
210
],
"flags": {
"collapsed": false
},
"order": 6,
"mode": 0,
"inputs": [
{
"label": "a",
"localized_name": "values.a",
"name": "values.a",
"type": "FLOAT,INT",
"link": 18
},
{
"label": "b",
"localized_name": "values.b",
"name": "values.b",
"shape": 7,
"type": "FLOAT,INT",
"link": 19
},
{
"label": "c",
"localized_name": "values.c",
"name": "values.c",
"shape": 7,
"type": "FLOAT,INT",
"link": null
},
{
"localized_name": "expression",
"name": "expression",
"type": "STRING",
"widget": {
"name": "expression"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": [
13
]
},
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": null
}
],
"properties": {
"Node name for S&R": "ComfyMathExpression",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"cnr_id": "comfy-core",
"ver": "0.19.3"
},
"widgets_values": [
"min(abs(b), 16) * a"
]
}
],
"groups": [],
"links": [
{
"id": 1,
"origin_id": 1,
"origin_slot": 0,
"target_id": 2,
"target_slot": 0,
"type": "INTERP_MODEL"
},
{
"id": 3,
"origin_id": 3,
"origin_slot": 0,
"target_id": 2,
"target_slot": 1,
"type": "IMAGE"
},
{
"id": 8,
"origin_id": 9,
"origin_slot": 0,
"target_id": 2,
"target_slot": 2,
"type": "INT"
},
{
"id": 4,
"origin_id": 2,
"origin_slot": 0,
"target_id": 5,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 5,
"origin_id": 3,
"origin_slot": 1,
"target_id": 5,
"target_slot": 1,
"type": "AUDIO"
},
{
"id": 12,
"origin_id": 10,
"origin_slot": 0,
"target_id": 5,
"target_slot": 2,
"type": "FLOAT"
},
{
"id": 11,
"origin_id": 3,
"origin_slot": 2,
"target_id": 10,
"target_slot": 0,
"type": "FLOAT"
},
{
"id": 13,
"origin_id": 11,
"origin_slot": 0,
"target_id": 10,
"target_slot": 1,
"type": "FLOAT"
},
{
"id": 15,
"origin_id": 13,
"origin_slot": 0,
"target_id": 10,
"target_slot": 2,
"type": "BOOLEAN"
},
{
"id": 18,
"origin_id": 3,
"origin_slot": 2,
"target_id": 11,
"target_slot": 0,
"type": "FLOAT"
},
{
"id": 19,
"origin_id": 9,
"origin_slot": 0,
"target_id": 11,
"target_slot": 1,
"type": "INT"
},
{
"id": 2,
"origin_id": -10,
"origin_slot": 0,
"target_id": 3,
"target_slot": 0,
"type": "VIDEO"
},
{
"id": 22,
"origin_id": -10,
"origin_slot": 1,
"target_id": 9,
"target_slot": 0,
"type": "INT"
},
{
"id": 23,
"origin_id": -10,
"origin_slot": 2,
"target_id": 13,
"target_slot": 0,
"type": "BOOLEAN"
},
{
"id": 24,
"origin_id": -10,
"origin_slot": 3,
"target_id": 1,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 26,
"origin_id": 5,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "VIDEO"
},
{
"id": 28,
"origin_id": 2,
"origin_slot": 0,
"target_id": -20,
"target_slot": 1,
"type": "IMAGE"
}
],
"extra": {},
"category": "Video Tools",
"description": "Increases video frame rate by synthesizing intermediate frames with a frame interpolation model."
}
]
},
"extra": {}
}

View File

@ -0,0 +1,485 @@
{
"revision": 0,
"last_node_id": 98,
"last_link_id": 0,
"nodes": [
{
"id": 98,
"type": "dca6e78d-fb06-421e-97f7-6ce17a665260",
"pos": [
-410,
-2230
],
"size": [
270,
104
],
"flags": {},
"order": 7,
"mode": 0,
"inputs": [
{
"name": "video",
"type": "VIDEO",
"link": null
},
{
"label": "frame_index",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": null
}
],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": []
}
],
"title": "Get Any Video Frame",
"properties": {
"proxyWidgets": [
[
"100",
"value"
]
]
},
"widgets_values": []
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "dca6e78d-fb06-421e-97f7-6ce17a665260",
"version": 1,
"state": {
"lastGroupId": 1,
"lastNodeId": 136,
"lastLinkId": 302,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Get Any Video Frame",
"inputNode": {
"id": -10,
"bounding": [
380,
-57,
120,
80
]
},
"outputNode": {
"id": -20,
"bounding": [
1460,
-57,
120,
60
]
},
"inputs": [
{
"id": "2ceec378-8dcf-4340-8570-155967f59a93",
"name": "video",
"type": "VIDEO",
"linkIds": [
4
],
"pos": [
480,
-37
]
},
{
"id": "819955f6-c686-4896-8032-ff2d0059109a",
"name": "value",
"type": "INT",
"linkIds": [
283
],
"label": "frame_index",
"pos": [
480,
-17
]
}
],
"outputs": [
{
"id": "1ab0684d-6a44-45b6-8aa4-a0b971a1d41e",
"name": "IMAGE",
"type": "IMAGE",
"linkIds": [
5
],
"pos": [
1480,
-37
]
}
],
"widgets": [],
"nodes": [
{
"id": 1,
"type": "GetVideoComponents",
"pos": [
560,
-150
],
"size": [
230,
120
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "video",
"name": "video",
"type": "VIDEO",
"link": 4
}
],
"outputs": [
{
"localized_name": "images",
"name": "images",
"type": "IMAGE",
"links": [
1,
2
]
},
{
"localized_name": "audio",
"name": "audio",
"type": "AUDIO",
"links": null
},
{
"localized_name": "fps",
"name": "fps",
"type": "FLOAT",
"links": null
}
],
"properties": {
"Node name for S&R": "GetVideoComponents"
}
},
{
"id": 2,
"type": "GetImageSize",
"pos": [
560,
50
],
"size": [
230,
120
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 1
}
],
"outputs": [
{
"localized_name": "width",
"name": "width",
"type": "INT",
"links": null
},
{
"localized_name": "height",
"name": "height",
"type": "INT",
"links": null
},
{
"localized_name": "batch_size",
"name": "batch_size",
"type": "INT",
"links": [
285
]
}
],
"properties": {
"Node name for S&R": "GetImageSize"
}
},
{
"id": 3,
"type": "ImageFromBatch",
"pos": [
1130,
-150
],
"size": [
270,
140
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 2
},
{
"localized_name": "batch_index",
"name": "batch_index",
"type": "INT",
"widget": {
"name": "batch_index"
},
"link": 286
},
{
"localized_name": "length",
"name": "length",
"type": "INT",
"widget": {
"name": "length"
},
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": [
5
]
}
],
"properties": {
"Node name for S&R": "ImageFromBatch"
},
"widgets_values": [
0,
1
]
},
{
"id": 99,
"type": "ComfyMathExpression",
"pos": [
910,
100
],
"size": [
400,
200
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"label": "a",
"localized_name": "values.a",
"name": "values.a",
"type": "FLOAT,INT",
"link": 284
},
{
"label": "b",
"localized_name": "values.b",
"name": "values.b",
"shape": 7,
"type": "FLOAT,INT",
"link": 285
},
{
"label": "c",
"localized_name": "values.c",
"name": "values.c",
"shape": 7,
"type": "FLOAT,INT",
"link": null
},
{
"localized_name": "expression",
"name": "expression",
"type": "STRING",
"widget": {
"name": "expression"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": null
},
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
286
]
}
],
"properties": {
"Node name for S&R": "ComfyMathExpression"
},
"widgets_values": [
"min(max(int(a if a >= 0 else b + a), 0), b - 1)"
]
},
{
"id": 100,
"type": "PrimitiveInt",
"pos": [
560,
250
],
"size": [
270,
110
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 283
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
284
]
}
],
"properties": {
"Node name for S&R": "PrimitiveInt"
},
"widgets_values": [
0,
"fixed"
]
}
],
"groups": [],
"links": [
{
"id": 1,
"origin_id": 1,
"origin_slot": 0,
"target_id": 2,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 2,
"origin_id": 1,
"origin_slot": 0,
"target_id": 3,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 4,
"origin_id": -10,
"origin_slot": 0,
"target_id": 1,
"target_slot": 0,
"type": "VIDEO"
},
{
"id": 5,
"origin_id": 3,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 283,
"origin_id": -10,
"origin_slot": 1,
"target_id": 100,
"target_slot": 0,
"type": "INT"
},
{
"id": 284,
"origin_id": 100,
"origin_slot": 0,
"target_id": 99,
"target_slot": 0,
"type": "INT"
},
{
"id": 285,
"origin_id": 2,
"origin_slot": 2,
"target_id": 99,
"target_slot": 1,
"type": "INT"
},
{
"id": 286,
"origin_id": 99,
"origin_slot": 1,
"target_id": 3,
"target_slot": 1,
"type": "INT"
}
],
"extra": {},
"category": "Video Tools",
"description": "Extracts one image frame from a video at a chosen index, with optional trim and FPS control."
}
]
},
"extra": {
"ds": {
"scale": 1.197015527856339,
"offset": [
-168.76833554248222,
540.6638955283997
]
},
"frontendVersion": "1.42.8"
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,714 @@
{
"revision": 0,
"last_node_id": 99,
"last_link_id": 0,
"nodes": [
{
"id": 99,
"type": "6e7ab3ea-96aa-470f-9b94-3d9d0e01f481",
"pos": [
-1630,
-3270
],
"size": [
290,
370
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"label": "image",
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": null
},
{
"label": "object",
"name": "text",
"type": "STRING",
"widget": {
"name": "text"
},
"link": null
},
{
"name": "bboxes",
"type": "BOUNDING_BOX",
"link": null
},
{
"name": "positive_coords",
"type": "STRING",
"link": null
},
{
"name": "negative_coords",
"type": "STRING",
"link": null
},
{
"name": "threshold",
"type": "FLOAT",
"widget": {
"name": "threshold"
},
"link": null
},
{
"name": "refine_iterations",
"type": "INT",
"widget": {
"name": "refine_iterations"
},
"link": null
},
{
"name": "individual_masks",
"type": "BOOLEAN",
"widget": {
"name": "individual_masks"
},
"link": null
},
{
"name": "ckpt_name",
"type": "COMBO",
"widget": {
"name": "ckpt_name"
},
"link": null
}
],
"outputs": [
{
"localized_name": "masks",
"name": "masks",
"type": "MASK",
"links": []
},
{
"localized_name": "bboxes",
"name": "bboxes",
"type": "BOUNDING_BOX",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"78",
"text"
],
[
"75",
"threshold"
],
[
"75",
"refine_iterations"
],
[
"75",
"individual_masks"
],
[
"77",
"ckpt_name"
]
],
"ue_properties": {
"widget_ue_connectable": {
"text": true
},
"version": "7.7",
"input_ue_unconnectable": {}
},
"cnr_id": "comfy-core",
"ver": "0.19.3",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [],
"title": "Image Segmentation (SAM3)"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "6e7ab3ea-96aa-470f-9b94-3d9d0e01f481",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 113,
"lastLinkId": 283,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Image Segmentation (SAM3)",
"inputNode": {
"id": -10,
"bounding": [
-2260,
-3450,
136.369140625,
220
]
},
"outputNode": {
"id": -20,
"bounding": [
-1130,
-3305,
120,
80
]
},
"inputs": [
{
"id": "a6e75fa2-162a-4af0-a2fd-1e9c899a5ab6",
"name": "image",
"type": "IMAGE",
"linkIds": [
264
],
"localized_name": "image",
"label": "image",
"pos": [
-2143.630859375,
-3430
]
},
{
"id": "3cefd304-7631-4ff6-a5a0-5a0ffb120745",
"name": "text",
"type": "STRING",
"linkIds": [
265
],
"label": "object",
"pos": [
-2143.630859375,
-3410
]
},
{
"id": "1aec91c5-d8d2-441c-928c-49c14e7e80ed",
"name": "bboxes",
"type": "BOUNDING_BOX",
"linkIds": [
266
],
"pos": [
-2143.630859375,
-3390
]
},
{
"id": "1ec7ce1a-8257-4719-8a81-60ebc8a98899",
"name": "positive_coords",
"type": "STRING",
"linkIds": [
267
],
"pos": [
-2143.630859375,
-3370
]
},
{
"id": "c65f8b87-9bd7-48be-9fc2-823431e95019",
"name": "negative_coords",
"type": "STRING",
"linkIds": [
268
],
"pos": [
-2143.630859375,
-3350
]
},
{
"id": "bb4ba35a-ccfe-4c37-98e5-d9b0d69585fb",
"name": "threshold",
"type": "FLOAT",
"linkIds": [
269
],
"pos": [
-2143.630859375,
-3330
]
},
{
"id": "b1439668-b050-490b-a5dc-fc4052c55666",
"name": "refine_iterations",
"type": "INT",
"linkIds": [
270
],
"pos": [
-2143.630859375,
-3310
]
},
{
"id": "86e239e5-c098-4302-b54d-d42a38bc0f89",
"name": "individual_masks",
"type": "BOOLEAN",
"linkIds": [
271
],
"pos": [
-2143.630859375,
-3290
]
},
{
"id": "f9e0b9d4-b2f1-4907-a4a5-305656576706",
"name": "ckpt_name",
"type": "COMBO",
"linkIds": [
272
],
"pos": [
-2143.630859375,
-3270
]
}
],
"outputs": [
{
"id": "ff50da09-1e59-4a58-9b7f-be1a00aa5913",
"name": "masks",
"type": "MASK",
"linkIds": [
231
],
"localized_name": "masks",
"pos": [
-1110,
-3285
]
},
{
"id": "8f622e40-8528-4078-b7d3-147e9f872194",
"name": "bboxes",
"type": "BOUNDING_BOX",
"linkIds": [
232
],
"localized_name": "bboxes",
"pos": [
-1110,
-3265
]
}
],
"widgets": [],
"nodes": [
{
"id": 75,
"type": "SAM3_Detect",
"pos": [
-1470,
-3460
],
"size": [
270,
260
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"label": "model",
"localized_name": "model",
"name": "model",
"type": "MODEL",
"link": 237
},
{
"label": "image",
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 264
},
{
"label": "conditioning",
"localized_name": "conditioning",
"name": "conditioning",
"shape": 7,
"type": "CONDITIONING",
"link": 200
},
{
"label": "bboxes",
"localized_name": "bboxes",
"name": "bboxes",
"shape": 7,
"type": "BOUNDING_BOX",
"link": 266
},
{
"label": "positive_coords",
"localized_name": "positive_coords",
"name": "positive_coords",
"shape": 7,
"type": "STRING",
"link": 267
},
{
"label": "negative_coords",
"localized_name": "negative_coords",
"name": "negative_coords",
"shape": 7,
"type": "STRING",
"link": 268
},
{
"localized_name": "threshold",
"name": "threshold",
"type": "FLOAT",
"widget": {
"name": "threshold"
},
"link": 269
},
{
"localized_name": "refine_iterations",
"name": "refine_iterations",
"type": "INT",
"widget": {
"name": "refine_iterations"
},
"link": 270
},
{
"localized_name": "individual_masks",
"name": "individual_masks",
"type": "BOOLEAN",
"widget": {
"name": "individual_masks"
},
"link": 271
}
],
"outputs": [
{
"localized_name": "masks",
"name": "masks",
"type": "MASK",
"links": [
231
]
},
{
"localized_name": "bboxes",
"name": "bboxes",
"type": "BOUNDING_BOX",
"links": [
232
]
}
],
"properties": {
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
},
"cnr_id": "comfy-core",
"ver": "0.19.3",
"Node name for S&R": "SAM3_Detect",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
0.5,
2,
false
]
},
{
"id": 77,
"type": "CheckpointLoaderSimple",
"pos": [
-1970,
-3200
],
"size": [
330,
140
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "ckpt_name",
"name": "ckpt_name",
"type": "COMBO",
"widget": {
"name": "ckpt_name"
},
"link": 272
}
],
"outputs": [
{
"localized_name": "MODEL",
"name": "MODEL",
"type": "MODEL",
"links": [
237
]
},
{
"localized_name": "CLIP",
"name": "CLIP",
"type": "CLIP",
"links": [
240
]
},
{
"localized_name": "VAE",
"name": "VAE",
"type": "VAE",
"links": null
}
],
"properties": {
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
},
"cnr_id": "comfy-core",
"ver": "0.19.3",
"Node name for S&R": "CheckpointLoaderSimple",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"models": [
{
"name": "sam3.1_multiplex_fp16.safetensors",
"url": "https://huggingface.co/Comfy-Org/sam3.1/resolve/main/checkpoints/sam3.1_multiplex_fp16.safetensors",
"directory": "checkpoints"
}
]
},
"widgets_values": [
"sam3.1_multiplex_fp16.safetensors"
]
},
{
"id": 78,
"type": "CLIPTextEncode",
"pos": [
-2000,
-3000
],
"size": [
400,
200
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "clip",
"name": "clip",
"type": "CLIP",
"link": 240
},
{
"localized_name": "text",
"name": "text",
"type": "STRING",
"widget": {
"name": "text"
},
"link": 265
}
],
"outputs": [
{
"localized_name": "CONDITIONING",
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
200
]
}
],
"properties": {
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
},
"cnr_id": "comfy-core",
"ver": "0.19.3",
"Node name for S&R": "CLIPTextEncode",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
""
]
}
],
"groups": [],
"links": [
{
"id": 237,
"origin_id": 77,
"origin_slot": 0,
"target_id": 75,
"target_slot": 0,
"type": "MODEL"
},
{
"id": 200,
"origin_id": 78,
"origin_slot": 0,
"target_id": 75,
"target_slot": 2,
"type": "CONDITIONING"
},
{
"id": 240,
"origin_id": 77,
"origin_slot": 1,
"target_id": 78,
"target_slot": 0,
"type": "CLIP"
},
{
"id": 231,
"origin_id": 75,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "MASK"
},
{
"id": 232,
"origin_id": 75,
"origin_slot": 1,
"target_id": -20,
"target_slot": 1,
"type": "BOUNDING_BOX"
},
{
"id": 264,
"origin_id": -10,
"origin_slot": 0,
"target_id": 75,
"target_slot": 1,
"type": "IMAGE"
},
{
"id": 265,
"origin_id": -10,
"origin_slot": 1,
"target_id": 78,
"target_slot": 1,
"type": "STRING"
},
{
"id": 266,
"origin_id": -10,
"origin_slot": 2,
"target_id": 75,
"target_slot": 3,
"type": "BOUNDING_BOX"
},
{
"id": 267,
"origin_id": -10,
"origin_slot": 3,
"target_id": 75,
"target_slot": 4,
"type": "STRING"
},
{
"id": 268,
"origin_id": -10,
"origin_slot": 4,
"target_id": 75,
"target_slot": 5,
"type": "STRING"
},
{
"id": 269,
"origin_id": -10,
"origin_slot": 5,
"target_id": 75,
"target_slot": 6,
"type": "FLOAT"
},
{
"id": 270,
"origin_id": -10,
"origin_slot": 6,
"target_id": 75,
"target_slot": 7,
"type": "INT"
},
{
"id": 271,
"origin_id": -10,
"origin_slot": 7,
"target_id": 75,
"target_slot": 8,
"type": "BOOLEAN"
},
{
"id": 272,
"origin_id": -10,
"origin_slot": 8,
"target_id": 77,
"target_slot": 0,
"type": "COMBO"
}
],
"extra": {},
"category": "Image Tools/Image Segmentation",
"description": "Segments images into masks using Meta SAM3 from text prompts, points, or boxes."
}
]
},
"extra": {
"ue_links": []
}
}

View File

@ -2028,7 +2028,7 @@
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Image to video",
"description": "Generates video from an image and text prompt using Wan 2.2, supporting T2V and I2V."
"description": "Image-to-video with Wan 2.2 using a start image plus text prompt to extend motion from the still frame."
}
]
},

View File

@ -0,0 +1,397 @@
{
"revision": 0,
"last_node_id": 19,
"last_link_id": 0,
"nodes": [
{
"id": 19,
"type": "5b40ca21-ba1a-41d5-b403-4d2d7acdc195",
"pos": [
-6411.330578108367,
1940.2638932730042
],
"size": [
349.609375,
145.9375
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": null
},
{
"name": "bg_removal_name",
"type": "COMBO",
"widget": {
"name": "bg_removal_name"
},
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": []
},
{
"name": "mask",
"type": "MASK",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"14",
"bg_removal_name"
]
]
},
"widgets_values": [],
"title": "Remove Background (BiRefNet)"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "5b40ca21-ba1a-41d5-b403-4d2d7acdc195",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 21,
"lastLinkId": 16,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Remove Background (BiRefNet)",
"description": "Removes or replaces image backgrounds using BiRefNet segmentation and alpha compositing.",
"inputNode": {
"id": -10,
"bounding": [
-6728.534070722246,
1475.2619799128663,
150.9140625,
88
]
},
"outputNode": {
"id": -20,
"bounding": [
-6169.049695722246,
1475.2619799128663,
128,
88
]
},
"inputs": [
{
"id": "7bc321cd-df31-4c39-aaf7-7f0d01326189",
"name": "image",
"type": "IMAGE",
"linkIds": [
5,
7
],
"localized_name": "image",
"pos": [
-6601.620008222246,
1499.2619799128663
]
},
{
"id": "e89d2cd8-daa3-4e29-8a69-851db85072cb",
"name": "bg_removal_name",
"type": "COMBO",
"linkIds": [
12
],
"pos": [
-6601.620008222246,
1519.2619799128663
]
}
],
"outputs": [
{
"id": "16e7863c-4c38-46c2-aa74-e82991fbfe8d",
"name": "IMAGE",
"type": "IMAGE",
"linkIds": [
8
],
"localized_name": "IMAGE",
"pos": [
-6145.049695722246,
1499.2619799128663
]
},
{
"id": "f7240c19-5b80-406e-a8e2-9b12440ee2d6",
"name": "mask",
"type": "MASK",
"linkIds": [
11
],
"pos": [
-6145.049695722246,
1519.2619799128663
]
}
],
"widgets": [],
"nodes": [
{
"id": 13,
"type": "RemoveBackground",
"pos": [
-6536.764823982709,
1444.9963409012412
],
"size": [
302.25,
72
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 5
},
{
"localized_name": "bg_removal_model",
"name": "bg_removal_model",
"type": "BACKGROUND_REMOVAL",
"link": 3
}
],
"outputs": [
{
"localized_name": "mask",
"name": "mask",
"type": "MASK",
"links": [
4,
11
]
}
],
"properties": {
"Node name for S&R": "RemoveBackground"
}
},
{
"id": 14,
"type": "LoadBackgroundRemovalModel",
"pos": [
-6540.534070722246,
1302.223464635445
],
"size": [
311.484375,
85.515625
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "bg_removal_name",
"name": "bg_removal_name",
"type": "COMBO",
"widget": {
"name": "bg_removal_name"
},
"link": 12
}
],
"outputs": [
{
"localized_name": "bg_model",
"name": "bg_model",
"type": "BACKGROUND_REMOVAL",
"links": [
3
]
}
],
"properties": {
"Node name for S&R": "LoadBackgroundRemovalModel",
"models": [
{
"name": "birefnet.safetensors",
"url": "https://huggingface.co/Comfy-Org/BiRefNet/resolve/main/background_removal/birefnet.safetensors",
"directory": "background_removal"
}
]
},
"widgets_values": [
"birefnet.safetensors"
]
},
{
"id": 15,
"type": "InvertMask",
"pos": [
-6532.446160529669,
1571.1111286839914
],
"size": [
285.984375,
48
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "mask",
"name": "mask",
"type": "MASK",
"link": 4
}
],
"outputs": [
{
"localized_name": "MASK",
"name": "MASK",
"type": "MASK",
"links": [
6
]
}
],
"properties": {
"Node name for S&R": "InvertMask"
}
},
{
"id": 16,
"type": "JoinImageWithAlpha",
"pos": [
-6527.4370171636665,
1674.3004951902876
],
"size": [
284.96875,
72
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 7
},
{
"localized_name": "alpha",
"name": "alpha",
"type": "MASK",
"link": 6
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": [
8
]
}
],
"properties": {
"Node name for S&R": "JoinImageWithAlpha"
}
}
],
"groups": [],
"links": [
{
"id": 3,
"origin_id": 14,
"origin_slot": 0,
"target_id": 13,
"target_slot": 1,
"type": "BACKGROUND_REMOVAL"
},
{
"id": 4,
"origin_id": 13,
"origin_slot": 0,
"target_id": 15,
"target_slot": 0,
"type": "MASK"
},
{
"id": 6,
"origin_id": 15,
"origin_slot": 0,
"target_id": 16,
"target_slot": 1,
"type": "MASK"
},
{
"id": 5,
"origin_id": -10,
"origin_slot": 0,
"target_id": 13,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 7,
"origin_id": -10,
"origin_slot": 0,
"target_id": 16,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 8,
"origin_id": 16,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 11,
"origin_id": 13,
"origin_slot": 0,
"target_id": -20,
"target_slot": 1,
"type": "MASK"
},
{
"id": 12,
"origin_id": -10,
"origin_slot": 1,
"target_id": 14,
"target_slot": 0,
"type": "COMBO"
}
],
"extra": {},
"category": "Image generation and editing/Background Removal"
}
]
},
"extra": {}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1030,7 +1030,7 @@
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Flux.1 [dev], Black Forest Labs' 12B diffusion model."
"description": "Generates images from prompts using FLUX.1 [dev]: a 12B rectified-flow MMDiT with dual CLIP plus T5-XXL text encoders and guidance-distilled sampling for sharp prompt following versus classic DDPM diffusion."
}
]
},

View File

@ -1024,7 +1024,7 @@
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Flux.1 Krea Dev, a Black Forest Labs × Krea collaboration variant."
"description": "FLUX.1 Krea [dev] (Black Forest Labs × Krea): open-weight 12B rectified-flow text-to-image drop-in alongside FLUX.1 [dev], tuned away from overcooked saturation toward more natural diversity in people, realism, and style while keeping ecosystem compatibility."
}
]
},

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,22 +1,21 @@
{
"id": "1c3eaa76-5cfa-4dc7-8571-97a570324e01",
"revision": 0,
"last_node_id": 34,
"last_link_id": 40,
"last_node_id": 57,
"last_link_id": 0,
"nodes": [
{
"id": 5,
"type": "dfe9eb32-97c0-43a5-90d5-4fd37768d91b",
"id": 57,
"type": "f2fdebf6-dfaf-43b6-9eb2-7f70613cfdc1",
"pos": [
-2.5766491043910378e-05,
1229.999928629805
130,
200
],
"size": [
400,
470
],
"flags": {},
"order": 0,
"order": 1,
"mode": 0,
"inputs": [
{
@ -44,6 +43,22 @@
},
"link": null
},
{
"name": "seed",
"type": "INT",
"widget": {
"name": "seed"
},
"link": null
},
{
"name": "steps",
"type": "INT",
"widget": {
"name": "steps"
},
"link": null
},
{
"name": "unet_name",
"type": "COMBO",
@ -80,15 +95,15 @@
"properties": {
"proxyWidgets": [
[
"-1",
"27",
"text"
],
[
"-1",
"13",
"width"
],
[
"-1",
"13",
"height"
],
[
@ -97,19 +112,23 @@
],
[
"3",
"control_after_generate"
"steps"
],
[
"-1",
"28",
"unet_name"
],
[
"-1",
"30",
"clip_name"
],
[
"-1",
"29",
"vae_name"
],
[
"3",
"control_after_generate"
]
],
"cnr_id": "comfy-core",
@ -122,29 +141,21 @@
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
"",
1024,
1024,
null,
null,
"z_image_turbo_bf16.safetensors",
"qwen_3_4b.safetensors",
"ae.safetensors"
]
"widgets_values": [],
"title": "Text to Image (Z-Image-Turbo)"
}
],
"links": [],
"groups": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "dfe9eb32-97c0-43a5-90d5-4fd37768d91b",
"id": "f2fdebf6-dfaf-43b6-9eb2-7f70613cfdc1",
"version": 1,
"state": {
"lastGroupId": 4,
"lastNodeId": 34,
"lastLinkId": 40,
"lastNodeId": 61,
"lastLinkId": 75,
"lastRerouteId": 0
},
"revision": 0,
@ -153,17 +164,17 @@
"inputNode": {
"id": -10,
"bounding": [
-80,
425,
-560,
480,
120,
160
200
]
},
"outputNode": {
"id": -20,
"bounding": [
1490,
415,
1670,
320,
120,
60
]
@ -178,8 +189,8 @@
],
"label": "prompt",
"pos": [
20,
445
-460,
500
]
},
{
@ -190,8 +201,8 @@
35
],
"pos": [
20,
465
-460,
520
]
},
{
@ -202,44 +213,68 @@
36
],
"pos": [
20,
485
-460,
540
]
},
{
"id": "23087d15-8412-4fbd-b71e-9b6d7ef76de1",
"id": "f77677f7-6bf6-4c19-a71f-c4a553d5981e",
"name": "seed",
"type": "INT",
"linkIds": [
71
],
"pos": [
-460,
560
]
},
{
"id": "ef9a9fb1-5983-4bc9-a60b-cf5aec48bff1",
"name": "steps",
"type": "INT",
"linkIds": [
72
],
"pos": [
-460,
580
]
},
{
"id": "a20a1b30-785f-4a04-bb6d-3d61adab9764",
"name": "unet_name",
"type": "COMBO",
"linkIds": [
38
73
],
"pos": [
20,
505
-460,
600
]
},
{
"id": "0677f5c3-2a3f-43d4-98ac-a4c56d5efdc0",
"id": "4af8fc2b-4655-4086-8240-45f8cb38c6f6",
"name": "clip_name",
"type": "COMBO",
"linkIds": [
39
74
],
"pos": [
20,
525
-460,
620
]
},
{
"id": "c85c0445-2641-48b1-bbca-95057edf2fcf",
"id": "4d518693-2807-439c-9cb6-cffd23ccba2c",
"name": "vae_name",
"type": "COMBO",
"linkIds": [
40
75
],
"pos": [
20,
545
-460,
640
]
}
],
@ -253,8 +288,8 @@
],
"localized_name": "IMAGE",
"pos": [
1510,
435
1690,
340
]
}
],
@ -264,15 +299,15 @@
"id": 30,
"type": "CLIPLoader",
"pos": [
109.99997264844609,
329.99999029608756
30,
420
],
"size": [
269.9869791666667,
106
270,
150
],
"flags": {},
"order": 0,
"order": 7,
"mode": 0,
"inputs": [
{
@ -282,7 +317,7 @@
"widget": {
"name": "clip_name"
},
"link": 39
"link": 74
},
{
"localized_name": "type",
@ -315,9 +350,9 @@
}
],
"properties": {
"Node name for S&R": "CLIPLoader",
"cnr_id": "comfy-core",
"ver": "0.3.73",
"Node name for S&R": "CLIPLoader",
"models": [
{
"name": "qwen_3_4b.safetensors",
@ -343,15 +378,15 @@
"id": 29,
"type": "VAELoader",
"pos": [
109.99997264844609,
479.9999847172637
30,
650
],
"size": [
269.9869791666667,
58
270,
110
],
"flags": {},
"order": 1,
"order": 6,
"mode": 0,
"inputs": [
{
@ -361,7 +396,7 @@
"widget": {
"name": "vae_name"
},
"link": 40
"link": 75
}
],
"outputs": [
@ -375,9 +410,9 @@
}
],
"properties": {
"Node name for S&R": "VAELoader",
"cnr_id": "comfy-core",
"ver": "0.3.73",
"Node name for S&R": "VAELoader",
"models": [
{
"name": "ae.safetensors",
@ -401,12 +436,12 @@
"id": 33,
"type": "ConditioningZeroOut",
"pos": [
639.9999103333332,
620.0000271257795
630,
960
],
"size": [
204.134765625,
26
230,
80
],
"flags": {},
"order": 8,
@ -430,9 +465,9 @@
}
],
"properties": {
"Node name for S&R": "ConditioningZeroOut",
"cnr_id": "comfy-core",
"ver": "0.3.73",
"Node name for S&R": "ConditioningZeroOut",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
@ -440,22 +475,21 @@
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": []
}
},
{
"id": 8,
"type": "VAEDecode",
"pos": [
1219.9999088104782,
160.00009184959066
1320,
230
],
"size": [
209.98697916666669,
46
230,
100
],
"flags": {},
"order": 5,
"order": 1,
"mode": 0,
"inputs": [
{
@ -483,9 +517,9 @@
}
],
"properties": {
"Node name for S&R": "VAEDecode",
"cnr_id": "comfy-core",
"ver": "0.3.64",
"Node name for S&R": "VAEDecode",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
@ -493,22 +527,21 @@
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": []
}
},
{
"id": 28,
"type": "UNETLoader",
"pos": [
109.99997264844609,
200.0000502647102
30,
230
],
"size": [
269.9869791666667,
82
270,
110
],
"flags": {},
"order": 2,
"order": 5,
"mode": 0,
"inputs": [
{
@ -518,7 +551,7 @@
"widget": {
"name": "unet_name"
},
"link": 38
"link": 73
},
{
"localized_name": "weight_dtype",
@ -541,9 +574,9 @@
}
],
"properties": {
"Node name for S&R": "UNETLoader",
"cnr_id": "comfy-core",
"ver": "0.3.73",
"Node name for S&R": "UNETLoader",
"models": [
{
"name": "z_image_turbo_bf16.safetensors",
@ -568,15 +601,15 @@
"id": 27,
"type": "CLIPTextEncode",
"pos": [
429.99997828947767,
200.0000502647102
400,
230
],
"size": [
409.9869791666667,
319.9869791666667
450,
650
],
"flags": {},
"order": 7,
"order": 4,
"mode": 0,
"inputs": [
{
@ -607,9 +640,9 @@
}
],
"properties": {
"Node name for S&R": "CLIPTextEncode",
"cnr_id": "comfy-core",
"ver": "0.3.73",
"Node name for S&R": "CLIPTextEncode",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
@ -626,15 +659,15 @@
"id": 13,
"type": "EmptySD3LatentImage",
"pos": [
109.99997264844609,
629.9999791384399
40,
890
],
"size": [
259.9869791666667,
106
260,
170
],
"flags": {},
"order": 6,
"order": 3,
"mode": 0,
"inputs": [
{
@ -677,9 +710,9 @@
}
],
"properties": {
"Node name for S&R": "EmptySD3LatentImage",
"cnr_id": "comfy-core",
"ver": "0.3.64",
"Node name for S&R": "EmptySD3LatentImage",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
@ -694,19 +727,77 @@
1
]
},
{
"id": 11,
"type": "ModelSamplingAuraFlow",
"pos": [
950,
230
],
"size": [
310,
110
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "model",
"name": "model",
"type": "MODEL",
"link": 26
},
{
"localized_name": "shift",
"name": "shift",
"type": "FLOAT",
"widget": {
"name": "shift"
},
"link": null
}
],
"outputs": [
{
"localized_name": "MODEL",
"name": "MODEL",
"type": "MODEL",
"slot_index": 0,
"links": [
13
]
}
],
"properties": {
"Node name for S&R": "ModelSamplingAuraFlow",
"cnr_id": "comfy-core",
"ver": "0.3.64",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
3
]
},
{
"id": 3,
"type": "KSampler",
"pos": [
879.9999615530063,
269.9999774911694
950,
400
],
"size": [
314.9869791666667,
262
320,
350
],
"flags": {},
"order": 4,
"order": 0,
"mode": 0,
"inputs": [
{
@ -740,7 +831,7 @@
"widget": {
"name": "seed"
},
"link": null
"link": 71
},
{
"localized_name": "steps",
@ -749,7 +840,7 @@
"widget": {
"name": "steps"
},
"link": null
"link": 72
},
{
"localized_name": "cfg",
@ -800,9 +891,9 @@
}
],
"properties": {
"Node name for S&R": "KSampler",
"cnr_id": "comfy-core",
"ver": "0.3.64",
"Node name for S&R": "KSampler",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
@ -814,81 +905,23 @@
"widgets_values": [
0,
"randomize",
4,
8,
1,
"res_multistep",
"simple",
1
]
},
{
"id": 11,
"type": "ModelSamplingAuraFlow",
"pos": [
879.9999615530063,
160.00009184959066
],
"size": [
309.9869791666667,
58
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "model",
"name": "model",
"type": "MODEL",
"link": 26
},
{
"localized_name": "shift",
"name": "shift",
"type": "FLOAT",
"widget": {
"name": "shift"
},
"link": null
}
],
"outputs": [
{
"localized_name": "MODEL",
"name": "MODEL",
"type": "MODEL",
"slot_index": 0,
"links": [
13
]
}
],
"properties": {
"cnr_id": "comfy-core",
"ver": "0.3.64",
"Node name for S&R": "ModelSamplingAuraFlow",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
3
]
}
],
"groups": [
{
"id": 2,
"title": "Image size",
"title": "Step2 - Image size",
"bounding": [
100,
560,
290,
200
10,
820,
320,
280
],
"color": "#3f789e",
"font_size": 24,
@ -896,12 +929,12 @@
},
{
"id": 3,
"title": "Prompt",
"title": "Step3 - Prompt",
"bounding": [
410,
360,
130,
450,
540
530,
970
],
"color": "#3f789e",
"font_size": 24,
@ -909,12 +942,12 @@
},
{
"id": 4,
"title": "Models",
"title": "Step1 - Load models",
"bounding": [
100,
0,
130,
290,
413.6
330,
660
],
"color": "#3f789e",
"font_size": 24,
@ -1027,25 +1060,41 @@
"type": "INT"
},
{
"id": 38,
"id": 71,
"origin_id": -10,
"origin_slot": 3,
"target_id": 3,
"target_slot": 4,
"type": "INT"
},
{
"id": 72,
"origin_id": -10,
"origin_slot": 4,
"target_id": 3,
"target_slot": 5,
"type": "INT"
},
{
"id": 73,
"origin_id": -10,
"origin_slot": 5,
"target_id": 28,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 39,
"id": 74,
"origin_id": -10,
"origin_slot": 4,
"origin_slot": 6,
"target_id": 30,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 40,
"id": 75,
"origin_id": -10,
"origin_slot": 5,
"origin_slot": 7,
"target_id": 29,
"target_slot": 0,
"type": "COMBO"
@ -1059,21 +1108,5 @@
}
]
},
"config": {},
"extra": {
"frontendVersion": "1.37.10",
"workflowRendererVersion": "LG",
"VHS_latentpreview": false,
"VHS_latentpreviewrate": 0,
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true,
"ds": {
"scale": 0.8401370345180755,
"offset": [
940.0587067393087,
-830.7121087564725
]
}
},
"version": 0.4
"extra": {}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,827 @@
{
"revision": 0,
"last_node_id": 130,
"last_link_id": 0,
"nodes": [
{
"id": 130,
"type": "7937cf78-b52b-40a3-93b2-b4e2e5f98df1",
"pos": [
-1210,
-2780
],
"size": [
300,
370
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "video",
"type": "VIDEO",
"link": null
},
{
"name": "text",
"type": "STRING",
"widget": {
"name": "text"
},
"link": null
},
{
"name": "bboxes",
"type": "BOUNDING_BOX",
"link": null
},
{
"name": "positive_coords",
"type": "STRING",
"link": null
},
{
"name": "negative_coords",
"type": "STRING",
"link": null
},
{
"name": "threshold",
"type": "FLOAT",
"widget": {
"name": "threshold"
},
"link": null
},
{
"name": "refine_iterations",
"type": "INT",
"widget": {
"name": "refine_iterations"
},
"link": null
},
{
"name": "individual_masks",
"type": "BOOLEAN",
"widget": {
"name": "individual_masks"
},
"link": null
},
{
"name": "ckpt_name",
"type": "COMBO",
"widget": {
"name": "ckpt_name"
},
"link": null
}
],
"outputs": [
{
"localized_name": "masks",
"name": "masks",
"type": "MASK",
"links": []
},
{
"localized_name": "bboxes",
"name": "bboxes",
"type": "BOUNDING_BOX",
"links": []
},
{
"name": "audio",
"type": "AUDIO",
"links": null
},
{
"name": "fps",
"type": "FLOAT",
"links": null
}
],
"properties": {
"proxyWidgets": [
[
"125",
"text"
],
[
"126",
"threshold"
],
[
"126",
"refine_iterations"
],
[
"126",
"individual_masks"
],
[
"127",
"ckpt_name"
]
],
"cnr_id": "comfy-core",
"ver": "0.19.3",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [],
"title": "Video Segmentation (SAM3)"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "7937cf78-b52b-40a3-93b2-b4e2e5f98df1",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 130,
"lastLinkId": 299,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Video Segmentation (SAM3)",
"inputNode": {
"id": -10,
"bounding": [
-2260,
-3450,
136.369140625,
220
]
},
"outputNode": {
"id": -20,
"bounding": [
-1050,
-3510,
120,
120
]
},
"inputs": [
{
"id": "680ffd88-32fe-48be-88d6-91ea44d5eaee",
"name": "video",
"type": "VIDEO",
"linkIds": [
252
],
"pos": [
-2143.630859375,
-3430
]
},
{
"id": "ceaf249c-32d7-4624-8bf6-e590e347ed90",
"name": "text",
"type": "STRING",
"linkIds": [
254
],
"pos": [
-2143.630859375,
-3410
]
},
{
"id": "1ffbff36-da0c-4854-8cb4-88ad31e64f99",
"name": "bboxes",
"type": "BOUNDING_BOX",
"linkIds": [
255
],
"pos": [
-2143.630859375,
-3390
]
},
{
"id": "67b7f4c7-cec0-4e00-b154-23cc1abf880e",
"name": "positive_coords",
"type": "STRING",
"linkIds": [
256
],
"pos": [
-2143.630859375,
-3370
]
},
{
"id": "b090a498-2bde-46b9-9554-18501401d687",
"name": "negative_coords",
"type": "STRING",
"linkIds": [
257
],
"pos": [
-2143.630859375,
-3350
]
},
{
"id": "1a76dfcf-ce95-46af-bba5-c42160c683dd",
"name": "threshold",
"type": "FLOAT",
"linkIds": [
261
],
"pos": [
-2143.630859375,
-3330
]
},
{
"id": "999523fa-c476-4c53-80c3-0a2f554d18ab",
"name": "refine_iterations",
"type": "INT",
"linkIds": [
262
],
"pos": [
-2143.630859375,
-3310
]
},
{
"id": "d2371011-7fe5-4a39-b0c1-df2e0bbd6ece",
"name": "individual_masks",
"type": "BOOLEAN",
"linkIds": [
263
],
"pos": [
-2143.630859375,
-3290
]
},
{
"id": "675a8b37-17db-48d1-853c-2fe5d6a74582",
"name": "ckpt_name",
"type": "COMBO",
"linkIds": [
273
],
"pos": [
-2143.630859375,
-3270
]
}
],
"outputs": [
{
"id": "ff50da09-1e59-4a58-9b7f-be1a00aa5913",
"name": "masks",
"type": "MASK",
"linkIds": [
231
],
"localized_name": "masks",
"pos": [
-1030,
-3490
]
},
{
"id": "8f622e40-8528-4078-b7d3-147e9f872194",
"name": "bboxes",
"type": "BOUNDING_BOX",
"linkIds": [
232
],
"localized_name": "bboxes",
"pos": [
-1030,
-3470
]
},
{
"id": "6c9924ec-f0fa-4509-83ea-8f97f5889bcc",
"name": "audio",
"type": "AUDIO",
"linkIds": [
259
],
"pos": [
-1030,
-3450
]
},
{
"id": "82c1cddc-ab11-44eb-9e2f-1a5c7ea5645b",
"name": "fps",
"type": "FLOAT",
"linkIds": [
260
],
"pos": [
-1030,
-3430
]
}
],
"widgets": [],
"nodes": [
{
"id": 125,
"type": "CLIPTextEncode",
"pos": [
-2010,
-3040
],
"size": [
400,
200
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "clip",
"name": "clip",
"type": "CLIP",
"link": 240
},
{
"localized_name": "text",
"name": "text",
"type": "STRING",
"widget": {
"name": "text"
},
"link": 254
}
],
"outputs": [
{
"localized_name": "CONDITIONING",
"name": "CONDITIONING",
"type": "CONDITIONING",
"links": [
200
]
}
],
"properties": {
"Node name for S&R": "CLIPTextEncode",
"cnr_id": "comfy-core",
"ver": "0.19.3",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
""
]
},
{
"id": 126,
"type": "SAM3_Detect",
"pos": [
-1520,
-3520
],
"size": [
270,
290
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"label": "model",
"localized_name": "model",
"name": "model",
"type": "MODEL",
"link": 237
},
{
"label": "image",
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 253
},
{
"label": "conditioning",
"localized_name": "conditioning",
"name": "conditioning",
"shape": 7,
"type": "CONDITIONING",
"link": 200
},
{
"label": "bboxes",
"localized_name": "bboxes",
"name": "bboxes",
"shape": 7,
"type": "BOUNDING_BOX",
"link": 255
},
{
"label": "positive_coords",
"localized_name": "positive_coords",
"name": "positive_coords",
"shape": 7,
"type": "STRING",
"link": 256
},
{
"label": "negative_coords",
"localized_name": "negative_coords",
"name": "negative_coords",
"shape": 7,
"type": "STRING",
"link": 257
},
{
"localized_name": "threshold",
"name": "threshold",
"type": "FLOAT",
"widget": {
"name": "threshold"
},
"link": 261
},
{
"localized_name": "refine_iterations",
"name": "refine_iterations",
"type": "INT",
"widget": {
"name": "refine_iterations"
},
"link": 262
},
{
"localized_name": "individual_masks",
"name": "individual_masks",
"type": "BOOLEAN",
"widget": {
"name": "individual_masks"
},
"link": 263
}
],
"outputs": [
{
"localized_name": "masks",
"name": "masks",
"type": "MASK",
"links": [
231
]
},
{
"localized_name": "bboxes",
"name": "bboxes",
"type": "BOUNDING_BOX",
"links": [
232
]
}
],
"properties": {
"Node name for S&R": "SAM3_Detect",
"cnr_id": "comfy-core",
"ver": "0.19.3",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
0.5,
2,
false
]
},
{
"id": 127,
"type": "CheckpointLoaderSimple",
"pos": [
-1970,
-3310
],
"size": [
330,
160
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "ckpt_name",
"name": "ckpt_name",
"type": "COMBO",
"widget": {
"name": "ckpt_name"
},
"link": 273
}
],
"outputs": [
{
"localized_name": "MODEL",
"name": "MODEL",
"type": "MODEL",
"links": [
237
]
},
{
"localized_name": "CLIP",
"name": "CLIP",
"type": "CLIP",
"links": [
240
]
},
{
"localized_name": "VAE",
"name": "VAE",
"type": "VAE",
"links": null
}
],
"properties": {
"Node name for S&R": "CheckpointLoaderSimple",
"cnr_id": "comfy-core",
"ver": "0.19.3",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"models": [
{
"name": "sam3.1_multiplex_fp16.safetensors",
"url": "https://huggingface.co/Comfy-Org/sam3.1/resolve/main/checkpoints/sam3.1_multiplex_fp16.safetensors",
"directory": "checkpoints"
}
]
},
"widgets_values": [
"sam3.1_multiplex_fp16.safetensors"
]
},
{
"id": 128,
"type": "GetVideoComponents",
"pos": [
-1910,
-3540
],
"size": [
230,
120
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"localized_name": "video",
"name": "video",
"type": "VIDEO",
"link": 252
}
],
"outputs": [
{
"localized_name": "images",
"name": "images",
"type": "IMAGE",
"links": [
253
]
},
{
"localized_name": "audio",
"name": "audio",
"type": "AUDIO",
"links": [
259
]
},
{
"localized_name": "fps",
"name": "fps",
"type": "FLOAT",
"links": [
260
]
}
],
"properties": {
"Node name for S&R": "GetVideoComponents",
"cnr_id": "comfy-core",
"ver": "0.19.3",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
}
},
{
"id": 129,
"type": "Note",
"pos": [
-1980,
-2790
],
"size": [
370,
250
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [],
"title": "Note: Prompt format",
"properties": {},
"widgets_values": [
"Max tokens for this model is only 32, to separately prompt multiple subjects you can separate prompts with comma, and set the max amount of objects detected for each prompt with :N\n\nFor example above test prompt finds 2 cakes, one apron, 4 window panels"
],
"color": "#432",
"bgcolor": "#653"
}
],
"groups": [],
"links": [
{
"id": 237,
"origin_id": 127,
"origin_slot": 0,
"target_id": 126,
"target_slot": 0,
"type": "MODEL"
},
{
"id": 200,
"origin_id": 125,
"origin_slot": 0,
"target_id": 126,
"target_slot": 2,
"type": "CONDITIONING"
},
{
"id": 240,
"origin_id": 127,
"origin_slot": 1,
"target_id": 125,
"target_slot": 0,
"type": "CLIP"
},
{
"id": 231,
"origin_id": 126,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "MASK"
},
{
"id": 232,
"origin_id": 126,
"origin_slot": 1,
"target_id": -20,
"target_slot": 1,
"type": "BOUNDING_BOX"
},
{
"id": 252,
"origin_id": -10,
"origin_slot": 0,
"target_id": 128,
"target_slot": 0,
"type": "VIDEO"
},
{
"id": 253,
"origin_id": 128,
"origin_slot": 0,
"target_id": 126,
"target_slot": 1,
"type": "IMAGE"
},
{
"id": 254,
"origin_id": -10,
"origin_slot": 1,
"target_id": 125,
"target_slot": 1,
"type": "STRING"
},
{
"id": 255,
"origin_id": -10,
"origin_slot": 2,
"target_id": 126,
"target_slot": 3,
"type": "BOUNDING_BOX"
},
{
"id": 256,
"origin_id": -10,
"origin_slot": 3,
"target_id": 126,
"target_slot": 4,
"type": "STRING"
},
{
"id": 257,
"origin_id": -10,
"origin_slot": 4,
"target_id": 126,
"target_slot": 5,
"type": "STRING"
},
{
"id": 259,
"origin_id": 128,
"origin_slot": 1,
"target_id": -20,
"target_slot": 2,
"type": "AUDIO"
},
{
"id": 260,
"origin_id": 128,
"origin_slot": 2,
"target_id": -20,
"target_slot": 3,
"type": "FLOAT"
},
{
"id": 261,
"origin_id": -10,
"origin_slot": 5,
"target_id": 126,
"target_slot": 6,
"type": "FLOAT"
},
{
"id": 262,
"origin_id": -10,
"origin_slot": 6,
"target_id": 126,
"target_slot": 7,
"type": "INT"
},
{
"id": 263,
"origin_id": -10,
"origin_slot": 7,
"target_id": 126,
"target_slot": 8,
"type": "BOOLEAN"
},
{
"id": 273,
"origin_id": -10,
"origin_slot": 8,
"target_id": 127,
"target_slot": 0,
"type": "COMBO"
}
],
"extra": {},
"category": "Video Tools",
"description": "Segments video into temporally consistent masks using Meta SAM3 from text or interactive prompts."
}
]
},
"extra": {}
}

View File

@ -1,21 +1,21 @@
{
"revision": 0,
"last_node_id": 84,
"last_node_id": 85,
"last_link_id": 0,
"nodes": [
{
"id": 84,
"type": "8e8aa94a-647e-436d-8440-8ee4691864de",
"id": 85,
"type": "637913e7-0206-46ba-8ded-70ae3a7c2e19",
"pos": [
-6100,
2620
-880,
-2260
],
"size": [
290,
160
],
"flags": {},
"order": 0,
"order": 2,
"mode": 0,
"inputs": [
{
@ -76,31 +76,26 @@
"properties": {
"proxyWidgets": [
[
"-1",
"79",
"direction"
],
[
"-1",
"79",
"match_image_size"
],
[
"-1",
"79",
"spacing_width"
],
[
"-1",
"79",
"spacing_color"
]
],
"cnr_id": "comfy-core",
"ver": "0.13.0"
},
"widgets_values": [
"right",
true,
0,
"white"
],
"widgets_values": [],
"title": "Video Stitch"
}
],
@ -109,12 +104,12 @@
"definitions": {
"subgraphs": [
{
"id": "8e8aa94a-647e-436d-8440-8ee4691864de",
"id": "637913e7-0206-46ba-8ded-70ae3a7c2e19",
"version": 1,
"state": {
"lastGroupId": 1,
"lastNodeId": 84,
"lastLinkId": 262,
"lastNodeId": 97,
"lastLinkId": 282,
"lastRerouteId": 0
},
"revision": 0,
@ -123,8 +118,8 @@
"inputNode": {
"id": -10,
"bounding": [
-6580,
2649,
-6810,
2580,
143.55859375,
160
]
@ -132,8 +127,8 @@
"outputNode": {
"id": -20,
"bounding": [
-5720,
2659,
-4770,
2600,
120,
60
]
@ -149,8 +144,8 @@
"localized_name": "video",
"label": "Before Video",
"pos": [
-6456.44140625,
2669
-6686.44140625,
2600
]
},
{
@ -163,8 +158,8 @@
"localized_name": "video_1",
"label": "After Video",
"pos": [
-6456.44140625,
2689
-6686.44140625,
2620
]
},
{
@ -175,8 +170,8 @@
259
],
"pos": [
-6456.44140625,
2709
-6686.44140625,
2640
]
},
{
@ -187,8 +182,8 @@
260
],
"pos": [
-6456.44140625,
2729
-6686.44140625,
2660
]
},
{
@ -199,8 +194,8 @@
261
],
"pos": [
-6456.44140625,
2749
-6686.44140625,
2680
]
},
{
@ -211,8 +206,8 @@
262
],
"pos": [
-6456.44140625,
2769
-6686.44140625,
2700
]
}
],
@ -226,8 +221,8 @@
],
"localized_name": "VIDEO",
"pos": [
-5700,
2679
-4750,
2620
]
}
],
@ -238,11 +233,11 @@
"type": "GetVideoComponents",
"pos": [
-6390,
2560
2600
],
"size": [
193.530859375,
66
230,
120
],
"flags": {},
"order": 1,
@ -278,9 +273,9 @@
}
],
"properties": {
"Node name for S&R": "GetVideoComponents",
"cnr_id": "comfy-core",
"ver": "0.13.0",
"Node name for S&R": "GetVideoComponents"
"ver": "0.13.0"
}
},
{
@ -291,8 +286,8 @@
2420
],
"size": [
193.530859375,
66
230,
120
],
"flags": {},
"order": 0,
@ -332,21 +327,254 @@
}
],
"properties": {
"Node name for S&R": "GetVideoComponents",
"cnr_id": "comfy-core",
"ver": "0.13.0",
"Node name for S&R": "GetVideoComponents"
"ver": "0.13.0"
}
},
{
"id": 90,
"type": "GetImageSize",
"pos": [
-6390,
3030
],
"size": [
230,
120
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 266
}
],
"outputs": [
{
"localized_name": "width",
"name": "width",
"type": "INT",
"links": [
274
]
},
{
"localized_name": "height",
"name": "height",
"type": "INT",
"links": [
276
]
},
{
"localized_name": "batch_size",
"name": "batch_size",
"type": "INT",
"links": null
}
],
"properties": {
"Node name for S&R": "GetImageSize"
}
},
{
"id": 80,
"type": "CreateVideo",
"pos": [
-5190,
2420
],
"size": [
270,
130
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "images",
"name": "images",
"type": "IMAGE",
"link": 282
},
{
"localized_name": "audio",
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": 251
},
{
"localized_name": "fps",
"name": "fps",
"type": "FLOAT",
"widget": {
"name": "fps"
},
"link": 252
}
],
"outputs": [
{
"localized_name": "VIDEO",
"name": "VIDEO",
"type": "VIDEO",
"links": [
255
]
}
],
"properties": {
"Node name for S&R": "CreateVideo",
"cnr_id": "comfy-core",
"ver": "0.13.0"
},
"widgets_values": [
30
]
},
{
"id": 95,
"type": "ComfyMathExpression",
"pos": [
-6040,
3020
],
"size": [
400,
200
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"label": "a",
"localized_name": "values.a",
"name": "values.a",
"type": "FLOAT,INT",
"link": 274
},
{
"label": "b",
"localized_name": "values.b",
"name": "values.b",
"shape": 7,
"type": "FLOAT,INT",
"link": null
},
{
"localized_name": "expression",
"name": "expression",
"type": "STRING",
"widget": {
"name": "expression"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": null
},
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
279
]
}
],
"properties": {
"Node name for S&R": "ComfyMathExpression"
},
"widgets_values": [
"a & ~1"
]
},
{
"id": 96,
"type": "ComfyMathExpression",
"pos": [
-6040,
3290
],
"size": [
400,
200
],
"flags": {},
"order": 6,
"mode": 0,
"inputs": [
{
"label": "a",
"localized_name": "values.a",
"name": "values.a",
"type": "FLOAT,INT",
"link": 276
},
{
"label": "b",
"localized_name": "values.b",
"name": "values.b",
"shape": 7,
"type": "FLOAT,INT",
"link": null
},
{
"localized_name": "expression",
"name": "expression",
"type": "STRING",
"widget": {
"name": "expression"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": null
},
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
280
]
}
],
"properties": {
"Node name for S&R": "ComfyMathExpression"
},
"widgets_values": [
"a & ~1"
]
},
{
"id": 79,
"type": "ImageStitch",
"pos": [
-6390,
2700
2780
],
"size": [
270,
150
160
],
"flags": {},
"order": 2,
@ -408,14 +636,15 @@
"name": "IMAGE",
"type": "IMAGE",
"links": [
250
266,
281
]
}
],
"properties": {
"Node name for S&R": "ImageStitch",
"cnr_id": "comfy-core",
"ver": "0.13.0",
"Node name for S&R": "ImageStitch"
"ver": "0.13.0"
},
"widgets_values": [
"right",
@ -425,60 +654,91 @@
]
},
{
"id": 80,
"type": "CreateVideo",
"id": 97,
"type": "ResizeImageMaskNode",
"pos": [
-6040,
2610
-5560,
2790
],
"size": [
270,
78
160
],
"flags": {},
"order": 3,
"order": 7,
"mode": 0,
"inputs": [
{
"localized_name": "images",
"name": "images",
"type": "IMAGE",
"link": 250
"localized_name": "input",
"name": "input",
"type": "IMAGE,MASK",
"link": 281
},
{
"localized_name": "audio",
"name": "audio",
"shape": 7,
"type": "AUDIO",
"link": 251
},
{
"localized_name": "fps",
"name": "fps",
"type": "FLOAT",
"localized_name": "resize_type",
"name": "resize_type",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "fps"
"name": "resize_type"
},
"link": 252
"link": null
},
{
"localized_name": "width",
"name": "resize_type.width",
"type": "INT",
"widget": {
"name": "resize_type.width"
},
"link": 279
},
{
"localized_name": "height",
"name": "resize_type.height",
"type": "INT",
"widget": {
"name": "resize_type.height"
},
"link": 280
},
{
"localized_name": "crop",
"name": "resize_type.crop",
"type": "COMBO",
"widget": {
"name": "resize_type.crop"
},
"link": null
},
{
"localized_name": "scale_method",
"name": "scale_method",
"type": "COMBO",
"widget": {
"name": "scale_method"
},
"link": null
}
],
"outputs": [
{
"localized_name": "VIDEO",
"name": "VIDEO",
"type": "VIDEO",
"localized_name": "resized",
"name": "resized",
"type": "*",
"links": [
255
282
]
}
],
"properties": {
"cnr_id": "comfy-core",
"ver": "0.13.0",
"Node name for S&R": "CreateVideo"
"Node name for S&R": "ResizeImageMaskNode"
},
"widgets_values": [
30
"scale dimensions",
512,
512,
"center",
"area"
]
}
],
@ -500,14 +760,6 @@
"target_slot": 1,
"type": "IMAGE"
},
{
"id": 250,
"origin_id": 79,
"origin_slot": 0,
"target_id": 80,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 251,
"origin_id": 77,
@ -579,6 +831,62 @@
"target_id": 79,
"target_slot": 5,
"type": "COMBO"
},
{
"id": 266,
"origin_id": 79,
"origin_slot": 0,
"target_id": 90,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 274,
"origin_id": 90,
"origin_slot": 0,
"target_id": 95,
"target_slot": 0,
"type": "INT"
},
{
"id": 276,
"origin_id": 90,
"origin_slot": 1,
"target_id": 96,
"target_slot": 0,
"type": "INT"
},
{
"id": 279,
"origin_id": 95,
"origin_slot": 1,
"target_id": 97,
"target_slot": 2,
"type": "INT"
},
{
"id": 280,
"origin_id": 96,
"origin_slot": 1,
"target_id": 97,
"target_slot": 3,
"type": "INT"
},
{
"id": 281,
"origin_id": 79,
"origin_slot": 0,
"target_id": 97,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 282,
"origin_id": 97,
"origin_slot": 0,
"target_id": 80,
"target_slot": 0,
"type": "IMAGE"
}
],
"extra": {
@ -588,5 +896,6 @@
"description": "Stitches multiple video clips into a single sequential video file."
}
]
}
},
"extra": {}
}

View File

@ -141,8 +141,7 @@ manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", he
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
vram_group.add_argument("--lowvram", action="store_true", help="Doesn't do anything if dynamic vram is enabled. If dynamic vram isn't being used this option makes the text encoders run on the CPU.")
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")

View File

@ -106,6 +106,7 @@ class Dino2Encoder(torch.nn.Module):
class Dino2PatchEmbeddings(torch.nn.Module):
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
super().__init__()
self.patch_size = patch_size
self.projection = operations.Conv2d(
in_channels=num_channels,
out_channels=dim,
@ -125,17 +126,37 @@ class Dino2Embeddings(torch.nn.Module):
super().__init__()
patch_size = 14
image_size = 518
self.patch_size = patch_size
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
class_pos = pos_embed[:, 0:1]
patch_pos = pos_embed[:, 1:]
N = patch_pos.shape[1]
M = int(N ** 0.5)
h0 = h_pixels // self.patch_size
w0 = w_pixels // self.patch_size
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
if x.shape[1] - 1 == self.position_embeddings.shape[1] - 1:
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
else:
h, w = pixel_values.shape[-2:]
x = x + self.interpolate_pos_encoding(x, h, w)
return x
@ -158,3 +179,21 @@ class Dinov2Model(torch.nn.Module):
x = self.layernorm(x)
pooled_output = x[:, 0, :]
return x, i, pooled_output, None
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
x = self.embeddings(pixel_values)
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
n_layers = len(self.encoder.layer)
resolved = [(i if i >= 0 else n_layers + i) for i in indices]
target = set(resolved)
max_idx = max(resolved)
n_skip = 1 # skip cls token
cache = {}
for i, layer in enumerate(self.encoder.layer):
x = layer(x, optimized_attention)
if i in target:
normed = self.layernorm(x) if apply_norm else x
cache[i] = (normed[:, n_skip:], normed[:, 0])
if i >= max_idx:
break
return [cache[i] for i in resolved]

View File

@ -242,6 +242,7 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -373,6 +374,7 @@ def sample_dpm_2_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -686,6 +688,7 @@ def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=Non
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
@ -747,6 +750,7 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
@ -832,6 +836,7 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
old_denoised = None
h, h_last = None, None
@ -889,6 +894,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
denoised_1, denoised_2 = None, None
h, h_1, h_2 = None, None, None
@ -1006,23 +1012,39 @@ def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None,
return generic_step_sampler(model, x, sigmas, extra_args, callback, disable, noise_sampler, DDPMSampler_step)
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, s_noise=1.0, s_noise_end=None, noise_clip_std=0.0):
# s_noise / s_noise_end: per-step noise multiplier, linearly interpolated across steps
# noise_clip_std: clamp injected noise to +/- N stddevs (0 disables).
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable):
n_steps = max(1, len(sigmas) - 1)
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
s_start = float(s_noise)
s_end = s_start if s_noise_end is None else float(s_noise_end)
for i in trange(n_steps, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
x = denoised
if sigmas[i + 1] > 0:
x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[i + 1], noise_sampler(sigmas[i], sigmas[i + 1]), x)
noise = noise_sampler(sigmas[i], sigmas[i + 1])
if noise_clip_std > 0:
clip_val = noise_clip_std * noise.std()
noise = noise.clamp(min=-clip_val, max=clip_val)
t = (i / (n_steps - 1)) if n_steps > 1 else 0.0
s_noise_i = s_start + (s_end - s_start) * t
if s_noise_i != 1.0:
noise = noise * s_noise_i
x = model_sampling.noise_scaling(sigmas[i + 1], noise, x)
return x
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
@ -1249,6 +1271,7 @@ def sample_euler_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=No
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
uncond_denoised = None
@ -1296,6 +1319,7 @@ def sample_dpmpp_2s_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
temp = [0]
def post_cfg_function(args):
@ -1371,6 +1395,7 @@ def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
@ -1504,6 +1529,7 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_noise = s_noise * getattr(model.inner_model.model_patcher.get_model_object('model_sampling'), "noise_scale", 1.0)
s_in = x.new_ones([x.shape[0]])
def default_er_sde_noise_scaler(x):
@ -1574,9 +1600,10 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
inject_noise = eta > 0 and s_noise > 0
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
@ -1645,9 +1672,10 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
inject_noise = eta > 0 and s_noise > 0
sigma_fn = partial(half_log_snr_to_sigma, model_sampling=model_sampling)
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
@ -1713,6 +1741,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
s_in = x.new_ones([x.shape[0]])
model_sampling = model.inner_model.model_patcher.get_model_object("model_sampling")
s_noise = s_noise * getattr(model_sampling, "noise_scale", 1.0)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
lambdas = sigma_to_half_log_snr(sigmas, model_sampling=model_sampling)

View File

@ -792,6 +792,13 @@ class ZImagePixelSpace(ChromaRadiance):
"""
pass
class HiDreamO1Pixel(ChromaRadiance):
"""Pixel-space latent format for HiDream-O1.
No VAE model patches/unpatches raw RGB internally with patch_size=32.
"""
pass
class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).

View File

@ -0,0 +1,41 @@
"""HiDream-O1 two-pass attention: tokens [0, ar_len) are causal, [ar_len, T)
attend full K/V. Splitting Q at the boundary avoids the (B, 1, T, T) additive
mask the general-purpose path would build (~500 MB at T~16K) and lets the
gen half hit the user's preferred backend via optimized_attention.
"""
import torch
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention
def make_two_pass_attention(ar_len: int, transformer_options=None):
"""Build a two-pass attention callable. AR pass uses SDPA-causal directly, gen pass routes through optimized_attention.
The AR pass goes through SDPA directand bypasses wrappers, it is only ~1% of T at typical edit sizes.
"""
def two_pass_attention(q, k, v, heads, **kwargs):
B, H, T, D = q.shape
if T < k.shape[2]: # KV-cache hot path: Q is shorter than K/V (cached AR prefix is in K/V only), all fresh Q positions are in the gen region, single full-attention call
out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
elif ar_len >= T:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
elif ar_len <= 0:
out = optimized_attention(q, k, v, heads, mask=None, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
else:
out_ar = comfy.ops.scaled_dot_product_attention(
q[:, :, :ar_len], k[:, :, :ar_len], v[:, :, :ar_len],
attn_mask=None, dropout_p=0.0, is_causal=True,
)
out_gen = optimized_attention(
q[:, :, ar_len:], k, v, heads,
mask=None, skip_reshape=True, skip_output_reshape=True,
transformer_options=transformer_options,
)
out = torch.cat([out_ar, out_gen], dim=2)
return out.transpose(1, 2).reshape(B, T, H * D)
return two_pass_attention

View File

@ -0,0 +1,230 @@
"""HiDream-O1 conditioning prep — ref-image dual path + extra_conds assembly.
Each ref image goes through two paths: a 32x32 patchified stream concatenated
to the noised target, and a Qwen3-VL ViT path producing tokens that scatter
into input_ids at <|image_pad|> positions.
"""
from typing import List
import torch
import comfy.utils
from comfy.text_encoders.qwen_vl import process_qwen2vl_images
from .utils import (PATCH_SIZE, calculate_dimensions, cond_image_size, ref_max_size, resize_tensor)
# Qwen3-VL ViT preprocessing constants (preprocessor_config.json).
VIT_PATCH = 16
VIT_MERGE = 2
VIT_IMAGE_MEAN = [0.5, 0.5, 0.5]
VIT_IMAGE_STD = [0.5, 0.5, 0.5]
def prepare_ref_images(
ref_images: List[torch.Tensor],
target_h: int,
target_w: int,
device: torch.device,
dtype: torch.dtype,
):
"""Build the dual-path tensors for K reference images at (target_h, target_w).
Returns None for K=0, else a dict with ref_patches, ref_pixel_values,
ref_image_grid_thw, per_ref_vit_tokens, per_ref_patch_grids.
"""
K = len(ref_images)
if K == 0:
return None
max_size = ref_max_size(max(target_h, target_w), K)
cis = cond_image_size(K)
refs_t = [img[0].clamp(0, 1).permute(2, 0, 1).unsqueeze(0).contiguous().float() for img in ref_images]
refs_t = [resize_tensor(t, max_size, PATCH_SIZE) for t in refs_t]
# 32-patch path.
ref_patches_per = []
per_ref_patch_grids = []
for t in refs_t:
t_norm = (t.squeeze(0) - 0.5) / 0.5 # (3, H, W) in [-1, 1]
h_p, w_p = t_norm.shape[-2] // PATCH_SIZE, t_norm.shape[-1] // PATCH_SIZE
per_ref_patch_grids.append((h_p, w_p))
patches = (
t_norm.reshape(3, h_p, PATCH_SIZE, w_p, PATCH_SIZE)
.permute(1, 3, 0, 2, 4)
.reshape(h_p * w_p, 3 * PATCH_SIZE * PATCH_SIZE)
)
ref_patches_per.append(patches)
ref_patches = torch.cat(ref_patches_per, dim=0).unsqueeze(0).to(device=device, dtype=dtype)
# ViT path.
refs_vlm_t = []
for t in refs_t:
_, _, h, w = t.shape
cond_w, cond_h = calculate_dimensions(cis, w / h)
cond_w = max(cond_w, VIT_PATCH * VIT_MERGE)
cond_h = max(cond_h, VIT_PATCH * VIT_MERGE)
refs_vlm_t.append(comfy.utils.common_upscale(t, cond_w, cond_h, "lanczos", "disabled"))
pv_list, grid_list, per_ref_vit_tokens = [], [], []
for t_v in refs_vlm_t:
pv, grid_thw = process_qwen2vl_images(
t_v.permute(0, 2, 3, 1),
min_pixels=0, max_pixels=10**12,
patch_size=VIT_PATCH, merge_size=VIT_MERGE,
image_mean=VIT_IMAGE_MEAN, image_std=VIT_IMAGE_STD,
)
grid_thw = grid_thw[0]
pv_list.append(pv.to(device=device, dtype=dtype))
grid_list.append(grid_thw.to(device=device))
# Post-merge token count = number of <|image_pad|> tokens this image expands to in input_ids.
gh, gw = int(grid_thw[1].item()), int(grid_thw[2].item())
per_ref_vit_tokens.append((gh // VIT_MERGE) * (gw // VIT_MERGE))
return {
"ref_patches": ref_patches,
"ref_pixel_values": torch.cat(pv_list, dim=0),
"ref_image_grid_thw": torch.stack(grid_list, dim=0),
"per_ref_vit_tokens": per_ref_vit_tokens,
"per_ref_patch_grids": per_ref_patch_grids,
}
def build_ref_input_ids(
text_input_ids: torch.Tensor,
per_ref_vit_tokens: List[int],
image_token_id: int,
vision_start_id: int,
vision_end_id: int,
):
"""Splice [vision_start, image_pad*N, vision_end] blocks into input_ids
after the [im_start, user, \\n] prefix (matches original chat template).
"""
ids = text_input_ids[0].tolist()
inserted = []
for n_pad in per_ref_vit_tokens:
inserted.extend([vision_start_id] + [image_token_id] * n_pad + [vision_end_id])
new_ids = ids[:3] + inserted + ids[3:] # 3 = len([im_start, user, \n])
return torch.tensor([new_ids], dtype=text_input_ids.dtype, device=text_input_ids.device)
def build_extra_conds(
text_input_ids: torch.Tensor,
noise: torch.Tensor,
ref_images: List[torch.Tensor] = None,
target_patch_size: int = 32,
):
"""Assemble all conditioning tensors for HiDreamO1Transformer.forward:
input_ids (with ref-vision tokens spliced in for the edit/IP path),
position_ids (MRoPE), token_types, vinput_mask, plus the ref
dual-path tensors when refs are provided.
"""
from .utils import get_rope_index_fix_point
from comfy.text_encoders.hidream_o1 import (
IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID,
)
if text_input_ids.dim() == 1:
text_input_ids = text_input_ids.unsqueeze(0)
text_input_ids = text_input_ids.long().to(noise.device)
B = noise.shape[0]
if text_input_ids.shape[0] == 1 and B > 1:
text_input_ids = text_input_ids.expand(B, -1)
H, W = noise.shape[-2], noise.shape[-1]
h_p, w_p = H // target_patch_size, W // target_patch_size
image_len = h_p * w_p
image_grid_thw_tgt = torch.tensor(
[[1, h_p, w_p]], dtype=torch.long, device=text_input_ids.device,
)
out = {}
if ref_images:
ref = prepare_ref_images(ref_images, H, W, device=noise.device, dtype=noise.dtype)
text_input_ids = build_ref_input_ids(
text_input_ids, ref["per_ref_vit_tokens"],
IMAGE_TOKEN_ID, VISION_START_ID, VISION_END_ID,
)
new_txt_len = text_input_ids.shape[1]
# Each ref's patchified stream gets a [vision_start, image_pad*N-1]
# block in the position-id stream after the noised target.
ref_grid_lengths = [hp * wp for (hp, wp) in ref["per_ref_patch_grids"]]
tgt_vision = torch.full((1, image_len), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
tgt_vision[:, 0] = VISION_START_ID
ref_vision_blocks = []
for rl in ref_grid_lengths:
blk = torch.full((1, rl), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
blk[:, 0] = VISION_START_ID
ref_vision_blocks.append(blk)
ref_vision_cat = torch.cat([tgt_vision] + ref_vision_blocks, dim=1)
input_ids_pad = torch.cat([text_input_ids, ref_vision_cat], dim=-1)
total_ref_patches_len = sum(ref_grid_lengths)
total_len = new_txt_len + image_len + total_ref_patches_len
# K (ViT, post-merge) + 1 (target) + K (ref-patches) image grids.
K = len(ref_images)
igthw_cond = ref["ref_image_grid_thw"].clone()
igthw_cond[:, 1] //= 2
igthw_cond[:, 2] //= 2
image_grid_thw_ref = torch.tensor(
[[1, hp, wp] for (hp, wp) in ref["per_ref_patch_grids"]],
dtype=torch.long, device=text_input_ids.device,
)
igthw_all = torch.cat([
igthw_cond.to(text_input_ids.device),
image_grid_thw_tgt,
image_grid_thw_ref,
], dim=0)
position_ids, _ = get_rope_index_fix_point(
spatial_merge_size=1,
image_token_id=IMAGE_TOKEN_ID,
vision_start_token_id=VISION_START_ID,
input_ids=input_ids_pad, image_grid_thw=igthw_all,
attention_mask=None,
skip_vision_start_token=[0] * K + [1] + [1] * K,
fix_point=4096,
)
# tms + target_image + ref_patches are all gen.
tms_pos = new_txt_len - 1
ar_len = tms_pos
token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
token_types[:, tms_pos:] = 1
vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
vinput_mask[:, new_txt_len:] = True
# Leading batch dim sidesteps CONDRegular.process_cond's repeat_to_batch_size truncation
out["ref_pixel_values"] = ref["ref_pixel_values"].unsqueeze(0)
out["ref_image_grid_thw"] = ref["ref_image_grid_thw"].unsqueeze(0)
out["ref_patches"] = ref["ref_patches"]
else:
# T2I: text + noised target only, vision_start replaces the first image token
txt_len = text_input_ids.shape[1]
total_len = txt_len + image_len
vision_tokens = torch.full((B, image_len), IMAGE_TOKEN_ID,
dtype=text_input_ids.dtype, device=text_input_ids.device)
vision_tokens[:, 0] = VISION_START_ID
input_ids_pad = torch.cat([text_input_ids, vision_tokens], dim=-1)
position_ids, _ = get_rope_index_fix_point(
spatial_merge_size=1,
image_token_id=IMAGE_TOKEN_ID,
vision_start_token_id=VISION_START_ID,
input_ids=input_ids_pad, image_grid_thw=image_grid_thw_tgt,
attention_mask=None,
skip_vision_start_token=[1],
)
ar_len = txt_len - 1
token_types = torch.zeros(B, total_len, dtype=torch.long, device=noise.device)
token_types[:, ar_len:] = 1
vinput_mask = torch.zeros(B, total_len, dtype=torch.bool, device=noise.device)
vinput_mask[:, txt_len:] = True
out["input_ids"] = text_input_ids
out["position_ids"] = position_ids[:, 0].unsqueeze(0) # Collapse position_ids batch and add a leading dim so CONDRegular's batch-resize doesn't truncate the 3-axis MRoPE dim
out["token_types"] = token_types
out["vinput_mask"] = vinput_mask
out["ar_len"] = ar_len
return out

View File

@ -0,0 +1,306 @@
"""HiDream-O1-Image transformer.
Pixel-space DiT built on Qwen3-VL: the vision tower (Qwen35VisionModel)
encodes ref images, the Qwen3-VL-8B decoder (Llama2_ with interleaved MRoPE)
processes a unified text+image sequence, and 32x32 patch embed/unembed
shims map raw RGB in and out of LLM hidden space. The Qwen3-VL deepstack
mergers go unused their weights are dropped at load.
"""
from dataclasses import dataclass, field
from typing import List, Optional
import einops
import torch
import torch.nn as nn
import comfy.patcher_extension
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.text_encoders.llama import Llama2_
from comfy.text_encoders.qwen35 import Qwen35VisionModel
from .attention import make_two_pass_attention
IMAGE_TOKEN_ID = 151655 # Qwen3-VL <|image_pad|>
TMS_TOKEN_ID = 151673 # HiDream-O1 <|tms_token|>
PATCH_SIZE = 32
@dataclass
class HiDreamO1TextConfig:
"""Qwen3-VL-8B text-decoder dims (matches public Qwen3-VL-8B-Instruct)."""
vocab_size: int = 151936
hidden_size: int = 4096
intermediate_size: int = 12288
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
head_dim: int = 128
max_position_embeddings: int = 128000
rms_norm_eps: float = 1e-6
rope_theta: float = 5000000.0
rope_scale: Optional[float] = None
rope_dims: List[int] = field(default_factory=lambda: [24, 20, 20])
interleaved_mrope: bool = True
transformer_type: str = "llama"
rms_norm_add: bool = False
mlp_activation: str = "silu"
qkv_bias: bool = False
q_norm: str = "gemma3"
k_norm: str = "gemma3"
final_norm: bool = True
lm_head: bool = False
stop_tokens: List[int] = field(default_factory=lambda: [151643, 151645])
QWEN3VL_VISION_DEFAULTS = dict(
hidden_size=1152,
num_heads=16,
intermediate_size=4304,
depth=27,
patch_size=16,
temporal_patch_size=2,
in_channels=3,
spatial_merge_size=2,
num_position_embeddings=2304,
deepstack_visual_indexes=(8, 16, 24),
out_hidden_size=4096, # final merger projects directly into LLM hidden
)
class BottleneckPatchEmbed(nn.Module):
# 3072 -> 1024 -> 4096 (raw 32x32 RGB patch -> bottleneck -> LLM hidden).
def __init__(self, patch_size=32, in_chans=3, pca_dim=1024, embed_dim=4096, bias=True, device=None, dtype=None, ops=None):
super().__init__()
self.proj1 = ops.Linear(patch_size * patch_size * in_chans, pca_dim, bias=False, device=device, dtype=dtype)
self.proj2 = ops.Linear(pca_dim, embed_dim, bias=bias, device=device, dtype=dtype)
def forward(self, x):
return self.proj2(self.proj1(x))
class FinalLayer(nn.Module):
# 4096 -> 3072 (LLM hidden -> flat pixel patch).
def __init__(self, hidden_size, patch_size=32, out_channels=3, device=None, dtype=None, ops=None):
super().__init__()
self.linear = ops.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.linear(x)
class HiDreamO1Transformer(nn.Module):
"""HiDream-O1 unified pixel-level transformer."""
def __init__(self, image_model=None, dtype=None, device=None, operations=None,
text_config_overrides=None, vision_config_overrides=None, **kwargs):
super().__init__()
self.dtype = dtype
text_cfg = HiDreamO1TextConfig(**(text_config_overrides or {}))
vision_cfg = dict(QWEN3VL_VISION_DEFAULTS)
if vision_config_overrides:
vision_cfg.update(vision_config_overrides)
vision_cfg["out_hidden_size"] = text_cfg.hidden_size
self.text_config = text_cfg
self.vision_config = vision_cfg
self.hidden_size = text_cfg.hidden_size
self.patch_size = PATCH_SIZE
self.in_channels = 3
self.tms_token_id = TMS_TOKEN_ID
self.visual = Qwen35VisionModel(vision_cfg, device=device, dtype=dtype, ops=operations)
self.language_model = Llama2_(text_cfg, device=device, dtype=dtype, ops=operations)
self.t_embedder1 = TimestepEmbedder(
text_cfg.hidden_size, device=device, dtype=dtype, operations=operations,
)
self.x_embedder = BottleneckPatchEmbed(
patch_size=self.patch_size, in_chans=self.in_channels,
pca_dim=text_cfg.hidden_size // 4, embed_dim=text_cfg.hidden_size,
bias=True, device=device, dtype=dtype, ops=operations,
)
self.final_layer2 = FinalLayer(
text_cfg.hidden_size, patch_size=self.patch_size,
out_channels=self.in_channels, device=device, dtype=dtype, ops=operations,
)
self._visual_cache = None
self._kv_cache_entries = []
def clear_kv_cache(self):
self._kv_cache_entries = []
self._visual_cache = None
def forward(self, x, timesteps, context=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timesteps, context, transformer_options, **kwargs)
def _forward(self, x, timesteps, context=None, transformer_options={}, input_ids=None, attention_mask=None, position_ids=None,
vinput_mask=None, ar_len=None, ref_pixel_values=None, ref_image_grid_thw=None, ref_patches=None, **kwargs):
"""Returns flow-match velocity (x - x_pred) / sigma"""
if input_ids is None or position_ids is None:
raise ValueError("HiDreamO1Transformer requires input_ids and position_ids in conditioning")
B, _, H, W = x.shape
h_p, w_p = H // self.patch_size, W // self.patch_size
tgt_image_len = h_p * w_p
z = einops.rearrange(
x, 'B C (H p1) (W p2) -> B (H W) (C p1 p2)',
p1=self.patch_size, p2=self.patch_size,
)
vinputs = torch.cat([z, ref_patches.to(z.dtype)], dim=1) if ref_patches is not None else z
inputs_embeds = self.language_model.embed_tokens(input_ids).to(x.dtype)
if ref_pixel_values is not None and ref_image_grid_thw is not None:
# ViT output is constant across sampling steps within a generation
# identity-key by the input tensor so refs don't recompute every step.
cached = self._visual_cache
if cached is not None and cached[0] is ref_pixel_values:
image_embeds = cached[1]
else:
ref_pv = ref_pixel_values.to(inputs_embeds.device)
ref_grid = ref_image_grid_thw.to(inputs_embeds.device).long()
# extra_conds wraps with a leading batch dim; refs are model-level so [0] always recovers them.
if ref_pv.dim() == 3:
ref_pv = ref_pv[0]
if ref_grid.dim() == 3:
ref_grid = ref_grid[0]
image_embeds = self.visual(ref_pv, ref_grid).to(inputs_embeds.dtype)
self._visual_cache = (ref_pixel_values, image_embeds)
# image_pad positions identical across batch (input_ids shared cond/uncond).
image_idx = (input_ids[0] == IMAGE_TOKEN_ID).nonzero(as_tuple=True)[0]
if image_idx.shape[0] != image_embeds.shape[0]:
raise ValueError(
f"Image-token count {image_idx.shape[0]} != ViT output count "
f"{image_embeds.shape[0]}; check tokenizer/processor alignment."
)
inputs_embeds[:, image_idx] = image_embeds.unsqueeze(0).expand(B, -1, -1)
sigma = timesteps.float() / 1000.0
t_pixeldit = 1.0 - sigma
t_emb = self.t_embedder1(t_pixeldit * 1000, inputs_embeds.dtype)
tms_mask_3d = (input_ids == self.tms_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = torch.where(tms_mask_3d, t_emb.unsqueeze(1).expand_as(inputs_embeds), inputs_embeds)
vinputs_embedded = self.x_embedder(vinputs.to(inputs_embeds.dtype))
inputs_embeds = torch.cat([inputs_embeds, vinputs_embedded], dim=1)
# extra_conds stores position_ids as (1, 3, T); process_cond repeats dim 0 to B. Take row 0.
freqs_cis = self.language_model.compute_freqs_cis(position_ids[0].to(x.device), x.device)
freqs_cis = tuple(t.to(x.dtype) for t in freqs_cis)
two_pass_attn = make_two_pass_attention(ar_len, transformer_options=transformer_options)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.language_model.layers)
transformer_options["block_type"] = "double"
# Cache prefix K/V across steps. Key includes input_ids (prompt), ref_id
# (refs scatter into inputs_embeds), and position_ids (RoPE baked into cached K).
can_cache = not blocks_replace and ar_len > 0
cache_len = ar_len if can_cache else 0
ref_id = id(ref_pixel_values) if ref_pixel_values is not None else None
pos_ids_key = position_ids[..., :cache_len] if can_cache else position_ids
cache_entries = self._kv_cache_entries
# Drop stale entries from a previous device (model was unloaded and reloaded).
if cache_entries and cache_entries[0]["input_ids"].device != input_ids.device:
cache_entries = []
self._kv_cache_entries = []
kv_cache = None
if can_cache:
for entry in cache_entries:
ck = entry["input_ids"]
ep = entry["position_ids"]
if (entry["cache_len"] == cache_len
and ck.shape == input_ids.shape and torch.equal(ck, input_ids)
and entry["ref_id"] == ref_id
and ep.shape == pos_ids_key.shape and torch.equal(ep, pos_ids_key)):
kv_cache = entry
break
if kv_cache is not None:
# Hot path: project Q/K/V only for fresh positions; past_key_value prepends cached AR K/V.
hidden_states = inputs_embeds[:, cache_len:]
sliced_freqs = tuple(t[..., cache_len:, :] for t in freqs_cis)
for i, layer in enumerate(self.language_model.layers):
transformer_options["block_index"] = i
K_i, V_i = kv_cache["kv"][i]
hidden_states, _ = layer(
x=hidden_states, attention_mask=None, freqs_cis=sliced_freqs, optimized_attention=two_pass_attn,
past_key_value=(K_i, V_i, cache_len),
)
else:
# Cold path: run full sequence; if cacheable, snapshot K/V at AR positions.
snapshots = [] if can_cache else None
past_kv_cold = () if can_cache else None
hidden_states = inputs_embeds
for i, layer in enumerate(self.language_model.layers):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args, _layer=layer):
out = {}
out["x"], _ = _layer(
x=args["x"], attention_mask=args.get("attention_mask"),
freqs_cis=args["freqs_cis"], optimized_attention=args["optimized_attention"],
past_key_value=None,
)
return out
out = blocks_replace[("double_block", i)](
{"x": hidden_states, "attention_mask": None,
"freqs_cis": freqs_cis, "optimized_attention": two_pass_attn,
"transformer_options": transformer_options},
{"original_block": block_wrap},
)
hidden_states = out["x"]
else:
hidden_states, present_kv = layer(
x=hidden_states, attention_mask=None,
freqs_cis=freqs_cis, optimized_attention=two_pass_attn,
past_key_value=past_kv_cold,
)
if snapshots is not None:
K, V, _ = present_kv
snapshots.append((K[:, :, :cache_len].contiguous(),
V[:, :, :cache_len].contiguous()))
if snapshots is not None:
# Cap at 2 entries (cond + uncond). Multi-cond workflows LRU-evict.
new_entry = {
"input_ids": input_ids.clone(),
"cache_len": cache_len,
"kv": snapshots,
"ref_id": ref_id,
"position_ids": pos_ids_key.clone(),
}
self._kv_cache_entries = (cache_entries + [new_entry])[-2:]
if self.language_model.norm is not None:
hidden_states = self.language_model.norm(hidden_states)
# Slice target-image positions before the final projection so the Linear only runs on tgt_image_len tokens.
# In the hot path hidden_states starts at original position cache_len, so masks/indices shift by cache_len.
sliced_offset = cache_len if kv_cache is not None else 0
if vinput_mask is not None:
vmask = vinput_mask.to(x.device).bool()
if sliced_offset > 0:
vmask = vmask[:, sliced_offset:]
target_hidden = hidden_states[vmask].view(B, -1, hidden_states.shape[-1])[:, :tgt_image_len]
else:
txt_seq_len = input_ids.shape[1]
start = txt_seq_len - sliced_offset
target_hidden = hidden_states[:, start:start + tgt_image_len]
x_pred_tgt = self.final_layer2(target_hidden)
# fp32 final subtraction, bf16 here noticeably degrades samples.
x_pred_img = einops.rearrange(
x_pred_tgt, 'B (H W) (C p1 p2) -> B C (H p1) (W p2)',
H=h_p, W=w_p, p1=self.patch_size, p2=self.patch_size,
)
return (x.float() - x_pred_img.float()) / sigma.view(B, 1, 1, 1).clamp_min(1e-3)

View File

@ -0,0 +1,173 @@
"""HiDream-O1 input-prep helpers: image/resolution math and unified-sequence
RoPE position-id assembly. The fix_point offset in get_rope_index_fix_point
lets the target image and patchified ref images share spatial RoPE positions
despite living at different sequence indices same 2D image plane.
"""
import math
from typing import Optional
import torch
PATCH_SIZE = 32
CONDITION_IMAGE_SIZE = 384 # ViT-side base size for ref images
def resize_tensor(img_t, image_size, patch_size=16):
"""img_t: (1, 3, H, W) float [0, 1]. Fit to image_size**2 area, patch-aligned, center-cropped."""
while min(img_t.shape[-2], img_t.shape[-1]) >= 2 * image_size: # Pre-halves with 2x2 box averaging while the image is still very large
img_t = torch.nn.functional.avg_pool2d(img_t, kernel_size=2, stride=2)
_, _, height, width = img_t.shape
m = patch_size
s_max = image_size * image_size
scale = math.sqrt(s_max / (width * height))
candidates = [
(round(width * scale) // m * m, round(height * scale) // m * m),
(round(width * scale) // m * m, math.floor(height * scale) // m * m),
(math.floor(width * scale) // m * m, round(height * scale) // m * m),
(math.floor(width * scale) // m * m, math.floor(height * scale) // m * m),
]
candidates = sorted(candidates, key=lambda x: x[0] * x[1], reverse=True)
new_size = candidates[-1]
for c in candidates:
if c[0] * c[1] <= s_max:
new_size = c
break
new_w, new_h = new_size
s1 = width / new_w
s2 = height / new_h
if s1 < s2:
resize_w, resize_h = new_w, round(height / s1)
else:
resize_w, resize_h = round(width / s2), new_h
img_t = torch.nn.functional.interpolate(img_t, size=(resize_h, resize_w), mode="bicubic")
top = (resize_h - new_h) // 2
left = (resize_w - new_w) // 2
return img_t[..., top:top + new_h, left:left + new_w]
def calculate_dimensions(max_size, ratio):
"""(W, H) for an aspect ratio fitting in max_size**2 area, 32-aligned."""
width = math.sqrt(max_size * max_size * ratio)
height = width / ratio
width = int(width / 32) * 32
height = int(height / 32) * 32
return width, height
def ref_max_size(target_max_dim, k):
"""K-dependent ref-image max dim before patchifying."""
if k == 1:
return target_max_dim
if k == 2:
return target_max_dim * 48 // 64
if k <= 4:
return target_max_dim // 2
if k <= 8:
return target_max_dim * 24 // 64
return target_max_dim // 4
def cond_image_size(k):
"""K-dependent ViT-side image size."""
if k <= 4:
return CONDITION_IMAGE_SIZE
if k <= 8:
return CONDITION_IMAGE_SIZE * 48 // 64
return CONDITION_IMAGE_SIZE // 2
def get_rope_index_fix_point(
spatial_merge_size: int,
image_token_id: int,
vision_start_token_id: int,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
skip_vision_start_token=None,
fix_point: int = 4096,
):
mrope_position_deltas = []
if input_ids is not None and image_grid_thw is not None:
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3, input_ids.shape[0], input_ids.shape[1],
dtype=input_ids.dtype, device=input_ids.device,
)
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids_b in enumerate(total_input_ids):
fp = fix_point
image_index = 0
input_ids_b = input_ids_b[attention_mask[i] == 1]
vision_start_indices = torch.argwhere(input_ids_b == vision_start_token_id).squeeze(1)
vision_tokens = input_ids_b[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
input_tokens = input_ids_b.tolist()
llm_pos_ids_list = []
st = 0
remain_images = image_nums
for _ in range(image_nums):
if image_token_id in input_tokens and remain_images > 0:
ed = input_tokens.index(image_token_id, st)
else:
ed = len(input_tokens) + 1
t = image_grid_thw[image_index][0]
h = image_grid_thw[image_index][1]
w = image_grid_thw[image_index][2]
image_index += 1
remain_images -= 1
llm_grid_t = t.item()
llm_grid_h = h.item() // spatial_merge_size
llm_grid_w = w.item() // spatial_merge_size
text_len = ed - st
text_len -= skip_vision_start_token[image_index - 1]
text_len = max(0, text_len)
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
if skip_vision_start_token[image_index - 1]:
if fp > 0:
fp = fp - st_idx
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + fp + st_idx)
fp = 0
else:
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1).expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1], device=input_ids.device, dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas

View File

@ -22,26 +22,25 @@ class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
"""
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
tensor: [batch, num_tokens, feature_dim] (per-token, default) or
[batch, num_frames, feature_dim] (per_frame=True, already compressed).
patches_per_frame: spatial patches per frame; pass None to disable compression.
"""
self.batch_size, num_tokens, self.feature_dim = tensor.shape
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.batch_size, n, self.feature_dim = tensor.shape
if per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
# All patches in a frame are identical, so we only keep the first one
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
self.num_frames = n
self.data = tensor
elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
self.patches_per_frame = patches_per_frame
self.num_frames = n // patches_per_frame
# All patches in a frame are identical — keep only the first.
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
else:
# Not divisible or too small - store directly without compression
self.patches_per_frame = 1
self.num_frames = num_tokens
self.num_frames = n
self.data = tensor
def expand(self):
@ -716,32 +715,35 @@ class LTXAVModel(LTXVModel):
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
# TODO: some code reuse is needed here.
grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep_scaled = timestep * self.timestep_scale_multiplier
v_timestep, v_embedded_timestep = self.adaln_single(
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]
# Reshape to [batch_size, num_tokens, dim] and compress for storage
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
# Used by compute_prompt_timestep and the audio cross-attention paths.
timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
if per_frame_path:
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
if grid_mask is not None:
# All-or-nothing per frame when has_spatial_mask=False.
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
ts_input = per_frame * self.timestep_scale_multiplier
else:
ts_input = timestep_scaled
v_timestep, v_embedded_timestep = self.adaln_single(
ts_input.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
v_prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype

View File

@ -358,6 +358,61 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
class GuideAttentionMask:
"""Holds the two per-group masks for LTXV guide self-attention.
_attention_with_guide_mask splits queries into noisy and tracked-guide
groups, so the largest mask is (1, 1, tracked_count, T).
"""
__slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")
def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
device = tracked_weights.device
dtype = tracked_weights.dtype
finfo = torch.finfo(dtype)
pos = tracked_weights > 0
log_w = torch.full_like(tracked_weights, finfo.min)
log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny))
self.guide_start = guide_start
self.tracked_count = tracked_count
self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1)
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
groups, so each group needs only its own sub-mask. Avoids materializing
the (1,1,T,T) dense mask.
"""
guide_start = guide_mask.guide_start
tracked_end = guide_start + guide_mask.tracked_count
out = torch.empty_like(q)
if guide_start > 0: # In practice currently guides are always after noise, guard for safety if this changes.
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False, # sageattn mask support is unreliable
)
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False,
)
if tracked_end < q.shape[1]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, tracked_end:, :], k, v, heads,
attn_precision=attn_precision, transformer_options=transformer_options,
)
return out
class CrossAttention(nn.Module):
def __init__(
self,
@ -412,8 +467,10 @@ class CrossAttention(nn.Module):
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
elif isinstance(mask, GuideAttentionMask):
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:
@ -1063,7 +1120,9 @@ class LTXVModel(LTXBaseModel):
additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
@ -1099,12 +1158,12 @@ class LTXVModel(LTXBaseModel):
if not resolved_entries:
return None
# Check if any attenuation is actually needed
needs_attenuation = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None
# strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
needs_mask = any(
e["strength"] != 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries
)
if not needs_attenuation:
if not needs_mask:
return None
# Build per-guide-token weights for all tracked guide tokens.
@ -1159,16 +1218,11 @@ class LTXVModel(LTXBaseModel):
# Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
if (tracked_weights >= 1.0).all():
# Skip when every weight is exactly 1.0 (additive bias would be 0).
if (tracked_weights == 1.0).all():
return None
# Build the mask: guide tokens are at the end of the sequence.
# Tracked guides come first (in order), untracked follow.
return self._build_self_attention_mask(
total_tokens, num_guide_tokens, total_tracked,
tracked_weights, guide_start, device, dtype,
)
return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
@staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
@ -1234,45 +1288,6 @@ class LTXVModel(LTXBaseModel):
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
tracked_weights, guide_start, device, dtype):
"""Build a log-space additive self-attention bias mask.
Attenuates attention between noisy tokens and tracked guide tokens.
Untracked guide tokens (at the end of the guide portion) keep full attention.
Args:
total_tokens: Total sequence length.
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
tracked_count: Number of tracked guide tokens (first in the guide portion).
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
guide_start: Index where guide tokens begin in the sequence.
device: Target device.
dtype: Target dtype.
Returns:
(1, 1, total_tokens, total_tokens) additive bias mask.
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
"""
finfo = torch.finfo(dtype)
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
tracked_end = guide_start + tracked_count
# Convert weights to log-space bias
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
log_w = torch.full_like(w, finfo.min)
positive_mask = w > 0
if positive_mask.any():
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
# noisy → tracked guides: each noisy row gets the same per-guide weight
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
return mask
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})

189
comfy/ldm/moge/geometry.py Normal file
View File

@ -0,0 +1,189 @@
"""Pure-torch + scipy geometry helpers for MoGe inference and mesh export."""
from __future__ import annotations
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from scipy.optimize import least_squares
def normalized_view_plane_uv(width: int, height: int, aspect_ratio: Optional[float] = None,
dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None) -> torch.Tensor:
"""Normalized view-plane UV coordinates with corners at +/-(W, H)/diagonal."""
if aspect_ratio is None:
aspect_ratio = width / height
span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
span_y = 1.0 / (1 + aspect_ratio ** 2) ** 0.5
u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
u, v = torch.meshgrid(u, v, indexing="xy")
return torch.stack([u, v], dim=-1)
def intrinsics_from_focal_center(fx: torch.Tensor, fy: torch.Tensor, cx: torch.Tensor, cy: torch.Tensor) -> torch.Tensor:
"""Assemble (..., 3, 3) intrinsics from broadcastable fx, fy, cx, cy."""
fx, fy, cx, cy = [torch.as_tensor(v) for v in (fx, fy, cx, cy)]
fx, fy, cx, cy = torch.broadcast_tensors(fx, fy, cx, cy)
zero = torch.zeros_like(fx)
one = torch.ones_like(fx)
return torch.stack([
torch.stack([fx, zero, cx], dim=-1),
torch.stack([zero, fy, cy], dim=-1),
torch.stack([zero, zero, one], dim=-1),
], dim=-2)
def depth_map_to_point_map(depth: torch.Tensor, intrinsics: torch.Tensor) -> torch.Tensor:
"""Back-project a (..., H, W) depth map through K^-1 to (..., H, W, 3) camera-space points.
Intrinsics use normalized image coords (x in [0, 1] left->right, y in [0, 1] top->bottom).
"""
H, W = depth.shape[-2:]
device, dtype = depth.device, depth.dtype
u = (torch.arange(W, dtype=dtype, device=device) + 0.5) / W
v = (torch.arange(H, dtype=dtype, device=device) + 0.5) / H
grid_v, grid_u = torch.meshgrid(v, u, indexing="ij")
pix = torch.stack([grid_u, grid_v, torch.ones_like(grid_u)], dim=-1)
K_inv = torch.linalg.inv(intrinsics)
rays = torch.einsum("...ij,hwj->...hwi", K_inv, pix)
return rays * depth.unsqueeze(-1)
def _solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray,
focal: Optional[float] = None) -> Tuple[float, float]:
"""LM-solve for z-shift; when focal is None, also recovers the optimal focal."""
uv = uv.reshape(-1, 2)
xy = xyz[..., :2].reshape(-1, 2)
z = xyz[..., 2].reshape(-1)
def fn(shift):
xy_proj = xy / (z + shift)[:, None]
f = focal if focal is not None else (xy_proj * uv).sum() / np.square(xy_proj).sum()
return (f * xy_proj - uv).ravel()
sol = least_squares(fn, x0=0.0, ftol=1e-3, method="lm")
shift = float(np.asarray(sol["x"]).squeeze())
if focal is None:
xy_proj = xy / (z + shift)[:, None]
focal = float((xy_proj * uv).sum() / np.square(xy_proj).sum())
return shift, focal
def recover_focal_shift(points: torch.Tensor, mask: Optional[torch.Tensor] = None,
focal: Optional[torch.Tensor] = None, downsample_size: Tuple[int, int] = (64, 64)
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Recover the focal length and z-shift that turn points into a metric point map.
Optical center is at the image center; returned focal is relative to half the image diagonal.
Returns (focal, shift) on the same device/dtype as points.
"""
shape = points.shape
H, W = shape[-3], shape[-2]
points_b = points.reshape(-1, H, W, 3)
mask_b = None if mask is None else mask.reshape(-1, H, W)
focal_b = None if focal is None else focal.reshape(-1)
uv = normalized_view_plane_uv(W, H, dtype=points.dtype, device=points.device)
points_lr = F.interpolate(points_b.permute(0, 3, 1, 2), downsample_size, mode="nearest").permute(0, 2, 3, 1)
uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode="nearest").squeeze(0).permute(1, 2, 0)
mask_lr = None
if mask_b is not None:
mask_lr = F.interpolate(mask_b.to(torch.float32).unsqueeze(1), downsample_size, mode="nearest").squeeze(1) > 0
uv_np = uv_lr.detach().cpu().numpy()
points_np = points_lr.detach().cpu().numpy()
mask_np = None if mask_lr is None else mask_lr.detach().cpu().numpy()
focal_np = None if focal_b is None else focal_b.detach().cpu().numpy()
out_focal: list = []
out_shift: list = []
for i in range(points_b.shape[0]):
if mask_np is None:
xyz_i = points_np[i].reshape(-1, 3)
uv_i = uv_np.reshape(-1, 2)
else:
sel = mask_np[i]
if sel.sum() < 2:
out_focal.append(1.0)
out_shift.append(0.0)
continue
xyz_i = points_np[i][sel]
uv_i = uv_np[sel]
if focal_np is None:
shift_i, focal_i = _solve_optimal_shift(uv_i, xyz_i)
out_focal.append(focal_i)
else:
shift_i, _ = _solve_optimal_shift(uv_i, xyz_i, focal=float(focal_np[i]))
out_shift.append(shift_i)
shift_t = torch.tensor(out_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])
if focal is None:
focal_t = torch.tensor(out_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
else:
focal_t = focal.reshape(shape[:-3])
return focal_t, shift_t
def depth_map_edge(depth: torch.Tensor, atol: Optional[float] = None, rtol: Optional[float] = None, kernel_size: int = 3) -> torch.Tensor:
"""Per-pixel boolean: True where the local depth window's max-min span exceeds atol or rtol*depth."""
shape = depth.shape
d = depth.reshape(-1, 1, *shape[-2:])
pad = kernel_size // 2
diff = F.max_pool2d(d, kernel_size, stride=1, padding=pad) + F.max_pool2d(-d, kernel_size, stride=1, padding=pad)
edge = torch.zeros_like(d, dtype=torch.bool)
if atol is not None:
edge |= diff > atol
if rtol is not None:
edge |= (diff / d.clamp_min(1e-6)).nan_to_num_() > rtol
return edge.reshape(*shape)
def triangulate_grid_mesh(points: torch.Tensor, mask: Optional[torch.Tensor] = None, decimation: int = 1, discontinuity_threshold: float = 0.04,
depth: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Triangulate a (H, W, 3) point map into (vertices, faces, uvs) on CPU.
Vertices: pixels with finite coords (passing optional mask). Quads with four valid corners
become two triangles. depth overrides the scalar used for the rtol edge check; pass radial
depth for panoramas (the default points[..., 2] goes negative below the equator).
"""
points = points.detach().cpu()
finite = torch.isfinite(points).all(dim=-1)
if mask is None:
mask = finite
else:
mask = mask.detach().cpu().to(torch.bool) & finite
if discontinuity_threshold > 0:
d = depth.detach().cpu() if depth is not None else points[..., 2]
# Replace inf with 0 so max-pool doesn't poison neighbourhoods (mask above already excludes those pixels).
d_finite = torch.where(finite, d, torch.zeros_like(d))
edge = depth_map_edge(d_finite, rtol=discontinuity_threshold)
mask = mask & ~edge
if decimation > 1:
points = points[::decimation, ::decimation].contiguous()
mask = mask[::decimation, ::decimation].contiguous()
H, W = points.shape[:2]
flat_mask = mask.reshape(-1)
idx = torch.full((H * W,), -1, dtype=torch.long)
n_valid = int(flat_mask.sum().item())
idx[flat_mask] = torch.arange(n_valid, dtype=torch.long)
idx = idx.reshape(H, W)
vertices = points.reshape(-1, 3)[flat_mask].contiguous()
yy, xx = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
u = xx.float() / max(W - 1, 1)
v = yy.float() / max(H - 1, 1)
uvs = torch.stack([u, v], dim=-1).reshape(-1, 2)[flat_mask].contiguous()
a, b, c, d = idx[:-1, :-1], idx[:-1, 1:], idx[1:, 1:], idx[1:, :-1]
quad_ok = (a >= 0) & (b >= 0) & (c >= 0) & (d >= 0)
a, b, c, d = a[quad_ok], b[quad_ok], c[quad_ok], d[quad_ok]
faces = torch.cat([torch.stack([a, b, c], dim=-1), torch.stack([a, c, d], dim=-1)], dim=0).contiguous()
return vertices, faces, uvs

347
comfy/ldm/moge/model.py Normal file
View File

@ -0,0 +1,347 @@
"""MoGe v1 / v2 inference modules and a state-dict-driven builder.
V1: DINOv2 backbone + multi-output head (points, mask).
V2: DINOv2 encoder + neck + per-output heads (points, mask, normal, optional metric-scale MLP).
"""
from __future__ import annotations
from numbers import Number
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
import comfy.model_management
import comfy.model_patcher
from comfy.image_encoders.dino2 import Dinov2Model
from .geometry import depth_map_to_point_map, intrinsics_from_focal_center, recover_focal_shift
from .modules import ConvStack, DINOv2Encoder, HeadV1, MLP, _view_plane_uv_grid
def _remap_points(points: torch.Tensor) -> torch.Tensor:
"""Apply the exp remap: z -> exp(z), xy stays linear and gets scaled by the new z."""
xy, z = points.split([2, 1], dim=-1)
z = torch.exp(z)
return torch.cat([xy * z, z], dim=-1)
def _detect_dinov2(sd: dict, prefix: str) -> Dict[str, Any]:
# All shipped MoGe checkpoints use plain DINOv2
hidden = sd[prefix + "embeddings.cls_token"].shape[-1]
layer_prefix = prefix + "encoder.layer."
depth = 1 + max(int(k[len(layer_prefix):].split(".")[0]) for k in sd if k.startswith(layer_prefix))
return {
"hidden_size": hidden,
"num_attention_heads": hidden // 64,
"num_hidden_layers": depth,
"layer_norm_eps": 1e-6,
"use_swiglu_ffn": False,
}
class MoGeModelV1(nn.Module):
"""MoGe v1: DINOv2 backbone + HeadV1 (points, mask)."""
image_mean: torch.Tensor
image_std: torch.Tensor
intermediate_layers = 4
num_tokens_range: Tuple[Number, Number] = (1200, 2500)
mask_threshold = 0.5
def __init__(self, backbone: Dict[str, Any], dim_upsample: List[int] = (256, 128, 128),
num_res_blocks: int = 1, dim_times_res_block_hidden: int = 1,
dtype=None, device=None, operations=comfy.ops.manual_cast):
super().__init__()
self.backbone = Dinov2Model(backbone, dtype, device, operations)
self.head = HeadV1(dim_in=backbone["hidden_size"], dim_upsample=list(dim_upsample),
num_res_blocks=num_res_blocks, dim_times_res_block_hidden=dim_times_res_block_hidden,
dtype=dtype, device=device, operations=operations)
self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
H, W = image.shape[-2:]
resize = ((num_tokens * 14 ** 2) / (H * W)) ** 0.5
rh, rw = int(H * resize), int(W * resize)
x = F.interpolate(image, (rh, rw), mode="bicubic", align_corners=False, antialias=True)
x = (x - self.image_mean) / self.image_std
x14 = F.interpolate(x, (rh // 14 * 14, rw // 14 * 14), mode="bilinear", align_corners=False, antialias=True)
n_layers = len(self.backbone.encoder.layer)
indices = list(range(n_layers - self.intermediate_layers, n_layers))
feats = self.backbone.get_intermediate_layers(x14, indices, apply_norm=True)
points, mask = self.head(feats, x)
points = F.interpolate(points.float(), (H, W), mode="bilinear", align_corners=False)
points = _remap_points(points.permute(0, 2, 3, 1))
mask = F.interpolate(mask.float(), (H, W), mode="bilinear", align_corners=False).squeeze(1)
return {"points": points, "mask": mask}
@classmethod
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
"""Detect the v1 head config from sd, build a model, and load weights."""
n_up = 1 + max(int(k.split(".")[2]) for k in sd if k.startswith("head.upsample_blocks."))
dim_upsample = [sd[f"head.upsample_blocks.{i}.0.0.weight"].shape[1] for i in range(n_up)]
# Each upsample stage is Sequential[upsampler, *res_blocks]; count res blocks at level 0.
num_res_blocks = max({int(k.split(".")[3]) for k in sd if k.startswith("head.upsample_blocks.0.")})
hidden_out = sd["head.upsample_blocks.0.1.layers.2.weight"].shape[0]
dim_times = max(hidden_out // dim_upsample[0], 1)
model = cls(backbone=_detect_dinov2(sd, prefix="backbone."),
dim_upsample=dim_upsample, num_res_blocks=num_res_blocks, dim_times_res_block_hidden=dim_times,
dtype=dtype, device=device, operations=operations)
model.load_state_dict(sd, strict=True)
return model
class MoGeModelV2(nn.Module):
"""MoGe v2: DINOv2 encoder + neck + per-output heads (points/mask/normal/metric-scale)."""
intermediate_layers = 4
num_tokens_range: Tuple[Number, Number] = (1200, 3600)
def __init__(self,
encoder: Dict[str, Any],
neck: Dict[str, Any],
points_head: Dict[str, Any],
mask_head: Dict[str, Any],
scale_head: Dict[str, Any],
normal_head: Optional[Dict[str, Any]] = None,
dtype=None, device=None, operations=comfy.ops.manual_cast):
super().__init__()
self.encoder = DINOv2Encoder(**encoder, dtype=dtype, device=device, operations=operations)
self.neck = ConvStack(**neck, dtype=dtype, device=device, operations=operations)
self.points_head = ConvStack(**points_head, dtype=dtype, device=device, operations=operations)
self.mask_head = ConvStack(**mask_head, dtype=dtype, device=device, operations=operations)
self.scale_head = MLP(**scale_head, dtype=dtype, device=device, operations=operations)
if normal_head is not None:
self.normal_head = ConvStack(**normal_head, dtype=dtype, device=device, operations=operations)
def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
B, _, H, W = image.shape
device, dtype = image.device, image.dtype
aspect_ratio = W / H
base_h = round((num_tokens / aspect_ratio) ** 0.5)
base_w = round((num_tokens * aspect_ratio) ** 0.5)
feat_top, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)
# 5-level pyramid: feat at level 0 concatenated with UV, other levels UV-only.
levels = [_view_plane_uv_grid(B, base_h * (2 ** L), base_w * (2 ** L), aspect_ratio, dtype, device)
for L in range(5)]
levels[0] = torch.cat([feat_top, levels[0]], dim=1)
feats = self.neck(levels)
def _resize(v):
return F.interpolate(v, (H, W), mode="bilinear", align_corners=False)
points = _remap_points(_resize(self.points_head(feats)[-1]).permute(0, 2, 3, 1))
mask = _resize(self.mask_head(feats)[-1]).squeeze(1).sigmoid()
metric_scale = self.scale_head(cls_token).squeeze(1).exp()
result = {"points": points, "mask": mask, "metric_scale": metric_scale}
if hasattr(self, "normal_head"):
normal = _resize(self.normal_head(feats)[-1])
result["normal"] = F.normalize(normal.permute(0, 2, 3, 1), dim=-1)
return result
@classmethod
def from_state_dict(cls, sd, dtype=None, device=None, operations=comfy.ops.manual_cast):
"""Detect the v2 encoder/neck/heads config from sd, build a model, and load weights."""
backbone = _detect_dinov2(sd, prefix="encoder.backbone.")
depth = backbone["num_hidden_layers"]
n = cls.intermediate_layers
encoder = {
"backbone": backbone,
"intermediate_layers": [(depth // n) * (i + 1) - 1 for i in range(n)],
"dim_out": sd["encoder.output_projections.0.weight"].shape[0],
}
# scale_head is an MLP: Sequential of [Linear, ReLU, ..., Linear]; Linear weight is (out, in).
scale_idxs = sorted({int(k.split(".")[1]) for k in sd if k.startswith("scale_head.")})
scale_first = sd[f"scale_head.{scale_idxs[0]}.weight"]
cfg: Dict[str, Any] = {
"encoder": encoder,
"neck": cls._detect_convstack(sd, "neck."),
"points_head": cls._detect_convstack(sd, "points_head."),
"mask_head": cls._detect_convstack(sd, "mask_head."),
"scale_head": {"dims": [scale_first.shape[1]] + [sd[f"scale_head.{i}.weight"].shape[0] for i in scale_idxs]},
}
if any(k.startswith("normal_head.") for k in sd):
cfg["normal_head"] = cls._detect_convstack(sd, "normal_head.")
model = cls(**cfg, dtype=dtype, device=device, operations=operations)
model.load_state_dict(sd, strict=True)
return model
@staticmethod
def _detect_convstack(sd: dict, prefix: str) -> Dict[str, Any]:
"""Reconstruct a ConvStack config from the keys under prefix"""
in_keys = [k for k in sd if k.startswith(f"{prefix}input_blocks.") and k.endswith(".weight")]
n = 1 + max(int(k[len(f"{prefix}input_blocks."):].split(".")[0]) for k in in_keys)
in_shapes = [sd[f"{prefix}input_blocks.{i}.weight"].shape for i in range(n)]
has_out = lambda i: f"{prefix}output_blocks.{i}.weight" in sd
has_norm = f"{prefix}res_blocks.0.0.layers.0.weight" in sd
def num_res_at(i):
rb_prefix = f"{prefix}res_blocks.{i}."
return len({int(k[len(rb_prefix):].split(".")[0]) for k in sd if k.startswith(rb_prefix)})
return {
"dim_in": [s[1] for s in in_shapes],
"dim_res_blocks": [s[0] for s in in_shapes],
"dim_out": [sd[f"{prefix}output_blocks.{i}.weight"].shape[0] if has_out(i) else None for i in range(n)],
"num_res_blocks": [num_res_at(i) for i in range(n)],
"resamplers": ["conv_transpose" if f"{prefix}resamplers.{i}.0.weight" in sd else "bilinear"
for i in range(n - 1)],
"res_block_in_norm": "layer_norm" if has_norm else "none",
"res_block_hidden_norm": "group_norm" if has_norm else "none",
}
# Translate the Meta-style DINOv2 keys MoGe ships to the naming ComfyUI DINOv2 port expects,
# and split each fused qkv tensor into Q/K/V.
_DINOV2_TOPLEVEL_RENAMES = {
"patch_embed.proj.weight": "embeddings.patch_embeddings.projection.weight",
"patch_embed.proj.bias": "embeddings.patch_embeddings.projection.bias",
"cls_token": "embeddings.cls_token",
"pos_embed": "embeddings.position_embeddings",
"register_tokens": "embeddings.register_tokens",
"mask_token": "embeddings.mask_token",
"norm.weight": "layernorm.weight",
"norm.bias": "layernorm.bias",
}
_DINOV2_BLOCK_RENAMES = [
("ls1.gamma", "layer_scale1.lambda1"),
("ls2.gamma", "layer_scale2.lambda1"),
("attn.proj.", "attention.output.dense."),
("mlp.w12.", "mlp.weights_in."),
("mlp.w3.", "mlp.weights_out."),
]
def _remap_state_dict(sd: dict) -> dict:
if "model" in sd and "model_config" in sd:
sd = sd["model"]
prefix = "encoder.backbone." if any(k.startswith("encoder.backbone.") for k in sd) else "backbone."
out: dict = {}
for k, v in sd.items():
if not k.startswith(prefix):
out[k] = v
continue
rel = k[len(prefix):]
if rel in _DINOV2_TOPLEVEL_RENAMES:
out[prefix + _DINOV2_TOPLEVEL_RENAMES[rel]] = v
continue
if not rel.startswith("blocks."):
out[k] = v
continue
_, idx, sub = rel.split(".", 2)
if sub in ("attn.qkv.weight", "attn.qkv.bias"):
tail = sub.rsplit(".", 1)[1]
q, kw, vw = v.chunk(3, dim=0)
base = f"{prefix}encoder.layer.{idx}.attention.attention"
out[f"{base}.query.{tail}"] = q
out[f"{base}.key.{tail}"] = kw
out[f"{base}.value.{tail}"] = vw
continue
for old, new in _DINOV2_BLOCK_RENAMES:
sub = sub.replace(old, new)
out[f"{prefix}encoder.layer.{idx}.{sub}"] = v
return out
def build_from_state_dict(sd: dict, dtype=None, device=None, operations=comfy.ops.manual_cast) -> nn.Module:
"""Dispatch to v1 or v2 based on the DINOv2 backbone prefix."""
sd = _remap_state_dict(sd)
cls = MoGeModelV2 if any(k.startswith("encoder.backbone.") for k in sd) else MoGeModelV1
return cls.from_state_dict(sd, dtype=dtype, device=device, operations=operations)
class MoGeModel:
"""Loaded MoGe model + ComfyUI memory management."""
def __init__(self, state_dict: dict):
# text encoder dtype closest match
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = build_from_state_dict(state_dict, dtype=self.dtype, device=offload_device, operations=comfy.ops.manual_cast).eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.version = "v2" if hasattr(self.model, "encoder") else "v1"
self.mask_threshold = float(getattr(self.model, "mask_threshold", 0.5))
nt = getattr(self.model, "num_tokens_range", (1200, 2500 if self.version == "v1" else 3600))
self.num_tokens_range = (int(nt[0]), int(nt[1]))
def infer(self, image: torch.Tensor, num_tokens: Optional[int] = None,
resolution_level: int = 9, fov_x: Optional[Union[Number, torch.Tensor]] = None,
force_projection: bool = True, apply_mask: bool = True,
apply_metric_scale: bool = True
) -> Dict[str, torch.Tensor]:
"""Run a single MoGe forward + post-process pass. image is (B, 3, H, W) in [0, 1]."""
comfy.model_management.load_model_gpu(self.patcher)
image = image.to(device=self.load_device, dtype=self.dtype)
H, W = image.shape[-2:]
aspect_ratio = W / H
if num_tokens is None:
lo, hi = self.num_tokens_range
num_tokens = int(lo + (resolution_level / 9) * (hi - lo))
out = self.model.forward(image, num_tokens=num_tokens)
points = out["points"].float() # recover_focal_shift goes through scipy on CPU; needs fp32.
mask_binary = out["mask"] > self.mask_threshold
normal = out.get("normal")
metric_scale = out.get("metric_scale")
diag = (1 + aspect_ratio ** 2) ** 0.5
def focal_from_fov_deg(deg):
fov = torch.as_tensor(deg, device=points.device, dtype=points.dtype)
return aspect_ratio / diag / torch.tan(torch.deg2rad(fov / 2))
if fov_x is None:
focal, shift = recover_focal_shift(points, mask_binary)
# Fall back to 60 deg FoV when the least-squares solver flips the focal sign.
bad = ~torch.isfinite(focal) | (focal <= 0)
if bool(bad.any()):
focal = torch.where(bad, focal_from_fov_deg(60.0), focal)
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
else:
focal = focal_from_fov_deg(fov_x).expand(points.shape[0])
_, shift = recover_focal_shift(points, mask_binary, focal=focal)
f_diag = focal / 2 * diag
half = torch.tensor(0.5, device=points.device, dtype=points.dtype)
intrinsics = intrinsics_from_focal_center(f_diag / aspect_ratio, f_diag, half, half)
points[..., 2] = points[..., 2] + shift[..., None, None]
# v2 only: filter mask by depth>0 to drop metric-scale negative-depth artifacts.
if self.version == "v2":
mask_binary = mask_binary & (points[..., 2] > 0)
depth = points[..., 2].clone()
if force_projection:
points = depth_map_to_point_map(depth, intrinsics=intrinsics)
if apply_metric_scale and metric_scale is not None:
points = points * metric_scale[:, None, None, None]
depth = depth * metric_scale[:, None, None]
if apply_mask:
points = torch.where(mask_binary[..., None], points, torch.full_like(points, float("inf")))
depth = torch.where(mask_binary, depth, torch.full_like(depth, float("inf")))
if normal is not None:
normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal))
result = {"points": points, "depth": depth, "intrinsics": intrinsics, "mask": mask_binary}
if normal is not None:
result["normal"] = normal
return result

204
comfy/ldm/moge/modules.py Normal file
View File

@ -0,0 +1,204 @@
"""Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head."""
from __future__ import annotations
from typing import List, Optional, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy.image_encoders.dino2 import Dinov2Model
from .geometry import normalized_view_plane_uv
def _conv2d(operations, c_in: int, c_out: int, k: int = 3, *, dtype=None, device=None):
return operations.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2, padding_mode="replicate", dtype=dtype, device=device)
def _view_plane_uv_grid(batch: int, height: int, width: int, aspect_ratio: float, dtype, device) -> torch.Tensor:
"""Batched normalized view-plane UV grid as a (B, 2, H, W) tensor."""
uv = normalized_view_plane_uv(width, height, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
return uv.permute(2, 0, 1).unsqueeze(0).expand(batch, -1, -1, -1)
def _concat_view_plane_uv(x: torch.Tensor, aspect_ratio: float) -> torch.Tensor:
"""Append a 2-channel normalized view-plane UV grid to x along the channel dim."""
uv = _view_plane_uv_grid(x.shape[0], x.shape[-2], x.shape[-1], aspect_ratio, x.dtype, x.device)
return torch.cat([x, uv], dim=1)
class ResidualConvBlock(nn.Module):
def __init__(self, channels: int, hidden_channels: Optional[int] = None, in_norm: str = "layer_norm", hidden_norm: str = "group_norm",
dtype=None, device=None, operations=comfy.ops.manual_cast):
super().__init__()
hidden_channels = hidden_channels if hidden_channels is not None else channels
in_norm_layer = operations.GroupNorm(1, channels, dtype=dtype, device=device) if in_norm == "layer_norm" else nn.Identity()
hidden_norm_layer = (operations.GroupNorm(max(hidden_channels // 32, 1), hidden_channels, dtype=dtype, device=device)
if hidden_norm == "group_norm" else nn.Identity())
self.layers = nn.Sequential(
in_norm_layer, nn.ReLU(), _conv2d(operations, channels, hidden_channels, dtype=dtype, device=device),
hidden_norm_layer, nn.ReLU(), _conv2d(operations, hidden_channels, channels, dtype=dtype, device=device),
)
def forward(self, x):
return self.layers(x) + x
class Resampler(nn.Sequential):
"""2x upsampler: ConvTranspose2d(2x2) or bilinear upsample, followed by a 3x3 conv."""
def __init__(self, in_channels: int, out_channels: int, type_: str, dtype=None, device=None, operations=comfy.ops.manual_cast):
if type_ == "conv_transpose":
up = operations.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, dtype=dtype, device=device)
conv_in = out_channels
else: # "bilinear"
up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
conv_in = in_channels
super().__init__(up, _conv2d(operations, conv_in, out_channels, dtype=dtype, device=device))
class MLP(nn.Sequential):
def __init__(self, dims: Sequence[int], dtype=None, device=None, operations=comfy.ops.manual_cast):
layers = []
for d_in, d_out in zip(dims[:-2], dims[1:-1]):
layers.append(operations.Linear(d_in, d_out, dtype=dtype, device=device))
layers.append(nn.ReLU(inplace=True))
layers.append(operations.Linear(dims[-2], dims[-1], dtype=dtype, device=device))
super().__init__(*layers)
class ConvStack(nn.Module):
def __init__(self, dim_in: List[Optional[int]], dim_res_blocks: List[int], dim_out: List[Optional[int]], resamplers: List[str],
num_res_blocks: List[int], dim_times_res_block_hidden: int = 1, res_block_in_norm: str = "layer_norm", res_block_hidden_norm: str = "group_norm",
dtype=None, device=None, operations=comfy.ops.manual_cast):
super().__init__()
self.input_blocks = nn.ModuleList([
(_conv2d(operations, d_in, d_res, k=1, dtype=dtype, device=device)
if d_in is not None else nn.Identity())
for d_in, d_res in zip(dim_in, dim_res_blocks)
])
self.resamplers = nn.ModuleList([
Resampler(prev, succ, type_=r, dtype=dtype, device=device, operations=operations)
for prev, succ, r in zip(dim_res_blocks[:-1], dim_res_blocks[1:], resamplers)
])
self.res_blocks = nn.ModuleList([
nn.Sequential(*[
ResidualConvBlock(d_res, dim_times_res_block_hidden * d_res, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm, dtype=dtype, device=device, operations=operations)
for _ in range(num_res_blocks[i])
])
for i, d_res in enumerate(dim_res_blocks)
])
self.output_blocks = nn.ModuleList([
(_conv2d(operations, d_res, d_out, k=1, dtype=dtype, device=device)
if d_out is not None else nn.Identity())
for d_out, d_res in zip(dim_out, dim_res_blocks)
])
def forward(self, in_features: List[Optional[torch.Tensor]]):
out_features = []
x = None
for i in range(len(self.res_blocks)):
feat = self.input_blocks[i](in_features[i]) if in_features[i] is not None else None
if i == 0:
x = feat
elif feat is not None:
x = x + feat
x = self.res_blocks[i](x)
out_features.append(self.output_blocks[i](x))
if i < len(self.res_blocks) - 1:
x = self.resamplers[i](x)
return out_features
class DINOv2Encoder(nn.Module):
"""Comfy DINOv2 backbone with per-layer 1x1 projection heads."""
def __init__(self, backbone: dict, intermediate_layers: List[int], dim_out: int, dtype=None, device=None, operations=comfy.ops.manual_cast):
super().__init__()
self.intermediate_layers = list(intermediate_layers)
dim_features = backbone["hidden_size"]
self.backbone = Dinov2Model(backbone, dtype, device, operations)
self.output_projections = nn.ModuleList([
_conv2d(operations, dim_features, dim_out, k=1, dtype=dtype, device=device)
for _ in range(len(self.intermediate_layers))
])
self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def forward(self, image: torch.Tensor, token_rows: int, token_cols: int,
return_class_token: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=True)
image_14 = (image_14 - self.image_mean) / self.image_std
feats = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, apply_norm=True)
x = torch.stack([
proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous())
for proj, (feat, _cls) in zip(self.output_projections, feats)
], dim=1).sum(dim=1)
if return_class_token:
return x, feats[-1][1]
return x
class HeadV1(nn.Module):
"""v1 head: 4 backbone-feature projections -> shared upsample stack -> per-target output convs (points, mask)."""
NUM_FEATURES = 4
DIM_PROJ = 512
DIM_OUT = (3, 1) # 3 channels for points, 1 for mask
LAST_CONV_CHANNELS = 32
def __init__(self, dim_in: int, dim_upsample: List[int] = (256, 128, 128), num_res_blocks: int = 1, dim_times_res_block_hidden: int = 1,
dtype=None, device=None, operations=comfy.ops.manual_cast):
super().__init__()
self.projects = nn.ModuleList([
_conv2d(operations, dim_in, self.DIM_PROJ, k=1, dtype=dtype, device=device)
for _ in range(self.NUM_FEATURES)
])
def upsampler(in_ch, out_ch):
return nn.Sequential(
operations.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2, dtype=dtype, device=device),
_conv2d(operations, out_ch, out_ch, dtype=dtype, device=device),
)
in_chs = [self.DIM_PROJ] + list(dim_upsample[:-1])
self.upsample_blocks = nn.ModuleList([
nn.Sequential(
upsampler(in_ch + 2, out_ch),
*(ResidualConvBlock(out_ch, dim_times_res_block_hidden * out_ch, dtype=dtype, device=device, operations=operations)
for _ in range(num_res_blocks))
)
for in_ch, out_ch in zip(in_chs, dim_upsample)
])
self.output_block = nn.ModuleList([
nn.Sequential(
_conv2d(operations, dim_upsample[-1] + 2, self.LAST_CONV_CHANNELS, dtype=dtype, device=device),
nn.ReLU(inplace=True),
_conv2d(operations, self.LAST_CONV_CHANNELS, d_out, k=1, dtype=dtype, device=device),
)
for d_out in self.DIM_OUT
])
def forward(self, hidden_states, image: torch.Tensor):
img_h, img_w = image.shape[-2:]
patch_h, patch_w = img_h // 14, img_w // 14
aspect = img_w / img_h
x = torch.stack([
proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
for proj, (feat, _cls) in zip(self.projects, hidden_states)
], dim=1).sum(dim=1)
for block in self.upsample_blocks:
x = block(_concat_view_plane_uv(x, aspect))
x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
x = _concat_view_plane_uv(x, aspect)
return [block(x) for block in self.output_block]

313
comfy/ldm/moge/panorama.py Normal file
View File

@ -0,0 +1,313 @@
"""Panorama (equirectangular) inference helpers for MoGe.
Splits an equirect into 12 perspective views via an icosahedron camera rig, runs
the model per view, and stitches per-view distance maps back into a single
equirect distance map via a multi-scale Poisson + gradient sparse solve.
Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
"""
from __future__ import annotations
from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from scipy.ndimage import convolve, map_coordinates
from scipy.sparse import vstack, csr_array
from scipy.sparse.linalg import lsmr
def _icosahedron_directions() -> np.ndarray:
"""12 icosahedron-vertex directions (non-normalised, matching upstream's vertex order)."""
A = (1.0 + np.sqrt(5.0)) / 2.0
return np.array([
[0, 1, A], [0, -1, A], [0, 1, -A], [0, -1, -A],
[1, A, 0], [-1, A, 0], [1, -A, 0], [-1, -A, 0],
[A, 0, 1], [A, 0, -1], [-A, 0, 1], [-A, 0, -1],
], dtype=np.float32)
def _intrinsics_from_fov(fov_x_rad: float, fov_y_rad: float) -> np.ndarray:
"""Normalised-image (unit-square) K matrix."""
fx = 0.5 / np.tan(fov_x_rad / 2)
fy = 0.5 / np.tan(fov_y_rad / 2)
return np.array([[fx, 0, 0.5], [0, fy, 0.5], [0, 0, 1]], dtype=np.float32)
def _extrinsics_look_at(eye: np.ndarray, target: np.ndarray, up: np.ndarray) -> np.ndarray:
"""OpenCV-convention world->camera extrinsics for an array of look-at targets (N, 4, 4)."""
eye = np.asarray(eye, dtype=np.float32)
target = np.asarray(target, dtype=np.float32)
up = np.asarray(up, dtype=np.float32)
if target.ndim == 1:
target = target[None]
fwd = target - eye
fwd = fwd / np.linalg.norm(fwd, axis=-1, keepdims=True).clip(1e-12)
right = np.cross(fwd, up)
right_norm = np.linalg.norm(right, axis=-1, keepdims=True)
# Fall back to an arbitrary perpendicular if forward is parallel to up.
parallel = right_norm.squeeze(-1) < 1e-6
if parallel.any():
alt_up = np.array([1, 0, 0], dtype=np.float32)
right = np.where(parallel[:, None], np.cross(fwd, alt_up), right)
right_norm = np.linalg.norm(right, axis=-1, keepdims=True)
right = right / right_norm.clip(1e-12)
new_up = np.cross(fwd, right)
R = np.stack([right, new_up, fwd], axis=-2)
t = -np.einsum("nij,j->ni", R, eye)
E = np.zeros((R.shape[0], 4, 4), dtype=np.float32)
E[:, :3, :3] = R
E[:, :3, 3] = t
E[:, 3, 3] = 1.0
return E
def get_panorama_cameras() -> Tuple[np.ndarray, List[np.ndarray]]:
"""Returns (extrinsics (12, 4, 4), [intrinsics] * 12) for icosahedron views at 90 deg FoV."""
targets = _icosahedron_directions()
eye = np.zeros(3, dtype=np.float32)
up = np.array([0, 0, 1], dtype=np.float32)
extrinsics = _extrinsics_look_at(eye, targets, up)
K = _intrinsics_from_fov(np.deg2rad(90.0), np.deg2rad(90.0))
return extrinsics, [K] * len(targets)
def spherical_uv_to_directions(uv: np.ndarray) -> np.ndarray:
"""Equirect UV in [0, 1] -> 3D unit-direction (Z up)."""
theta = (1 - uv[..., 0]) * (2 * np.pi)
phi = uv[..., 1] * np.pi
return np.stack([
np.sin(phi) * np.cos(theta),
np.sin(phi) * np.sin(theta),
np.cos(phi),
], axis=-1).astype(np.float32)
def directions_to_spherical_uv(directions: np.ndarray) -> np.ndarray:
"""3D direction -> equirect UV in [0, 1]."""
n = np.linalg.norm(directions, axis=-1, keepdims=True).clip(1e-12)
d = directions / n
u = 1 - np.arctan2(d[..., 1], d[..., 0]) / (2 * np.pi) % 1.0
v = np.arccos(d[..., 2].clip(-1, 1)) / np.pi
return np.stack([u, v], axis=-1).astype(np.float32)
def _uv_grid(H: int, W: int) -> np.ndarray:
"""Pixel-center UV grid in [0, 1]; (H, W, 2)."""
u = (np.arange(W, dtype=np.float32) + 0.5) / W
v = (np.arange(H, dtype=np.float32) + 0.5) / H
return np.stack(np.meshgrid(u, v, indexing="xy"), axis=-1)
def _unproject_cv(uv: np.ndarray, depth: np.ndarray,
extrinsics: np.ndarray, intrinsics: np.ndarray) -> np.ndarray:
"""Back-project pixels into world coords (OpenCV convention)."""
pix = np.concatenate([uv, np.ones_like(uv[..., :1])], axis=-1)
K_inv = np.linalg.inv(intrinsics)
cam = pix @ K_inv.T * depth[..., None]
cam_h = np.concatenate([cam, np.ones_like(cam[..., :1])], axis=-1)
E_inv = np.linalg.inv(extrinsics)
return (cam_h @ E_inv.T)[..., :3]
def _project_cv(points: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""World coords -> (uv, depth) in the camera (OpenCV convention)."""
pts_h = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1)
cam = pts_h @ extrinsics.T
cam_xyz = cam[..., :3]
depth = cam_xyz[..., 2]
proj = cam_xyz @ intrinsics.T
uv = proj[..., :2] / proj[..., 2:3].clip(1e-12)
return uv.astype(np.float32), depth.astype(np.float32)
def _grid_sample_uv(img_bchw: torch.Tensor, uv: torch.Tensor, mode: str = "bilinear") -> torch.Tensor:
"""Sample img_bchw at UV-in-[0,1] coords uv of shape (B, H, W, 2); replicate-border."""
grid = uv * 2.0 - 1.0
return F.grid_sample(img_bchw, grid, mode=mode, padding_mode="border", align_corners=False)
def split_panorama_image(image: torch.Tensor, extrinsics: np.ndarray, intrinsics: List[np.ndarray], resolution: int) -> torch.Tensor:
"""(3, Hp, Wp) equirect on any device -> (N, 3, R, R) perspective crops on the same device."""
device = image.device
N = len(extrinsics)
uv = _uv_grid(resolution, resolution)
sample_uvs = []
for i in range(N):
world = _unproject_cv(uv, np.ones(uv.shape[:-1], dtype=np.float32), extrinsics[i], intrinsics[i])
sample_uvs.append(directions_to_spherical_uv(world))
sample_uvs = np.stack(sample_uvs, axis=0)
img_bchw = image.unsqueeze(0).expand(N, -1, -1, -1).contiguous()
sample_uvs_t = torch.from_numpy(sample_uvs).to(device=device, dtype=image.dtype)
return _grid_sample_uv(img_bchw, sample_uvs_t, mode="bilinear")
def _poisson_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
"""Sparse Laplacian operator over the H x W grid."""
grid_index = np.arange(H * W).reshape(H, W)
grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode="wrap" if wrap_x else "edge")
grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode="wrap" if wrap_y else "edge")
data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(H * W, axis=0).reshape(-1)
indices = np.stack([
grid_index[1:-1, 1:-1],
grid_index[:-2, 1:-1], grid_index[2:, 1:-1],
grid_index[1:-1, :-2], grid_index[1:-1, 2:],
], axis=-1).reshape(-1)
indptr = np.arange(0, H * W * 5 + 1, 5)
return csr_array((data, indices, indptr), shape=(H * W, H * W))
def _grad_equation(W: int, H: int, wrap_x: bool = False, wrap_y: bool = False):
"""Sparse forward-difference operator over the H x W grid."""
grid_index = np.arange(W * H).reshape(H, W)
if wrap_x:
grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode="wrap")
if wrap_y:
grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode="wrap")
data = np.concatenate([
np.concatenate([
np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),
-np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),
], axis=1).reshape(-1),
np.concatenate([
np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),
-np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),
], axis=1).reshape(-1),
])
indices = np.concatenate([
np.concatenate([grid_index[:, :-1].reshape(-1, 1), grid_index[:, 1:].reshape(-1, 1)], axis=1).reshape(-1),
np.concatenate([grid_index[:-1, :].reshape(-1, 1), grid_index[1:, :].reshape(-1, 1)], axis=1).reshape(-1),
])
nx = grid_index.shape[0] * (grid_index.shape[1] - 1)
ny = (grid_index.shape[0] - 1) * grid_index.shape[1]
indptr = np.arange(0, nx * 2 + ny * 2 + 1, 2)
return csr_array((data, indices, indptr), shape=(nx + ny, H * W))
def _scipy_remap_bilinear(img: np.ndarray, sample_pixels: np.ndarray, mode: str = "bilinear") -> np.ndarray:
"""Bilinear/nearest sampling at fractional pixel coords; out-of-range clamps to nearest border."""
H, W = img.shape[:2]
yy = np.clip(sample_pixels[..., 1], 0, H - 1)
xx = np.clip(sample_pixels[..., 0], 0, W - 1)
order = 1 if mode == "bilinear" else 0
if img.ndim == 2:
return map_coordinates(img, [yy, xx], order=order, mode="nearest").astype(img.dtype)
out = np.stack([
map_coordinates(img[..., c], [yy, xx], order=order, mode="nearest")
for c in range(img.shape[-1])
], axis=-1)
return out.astype(img.dtype)
def merge_panorama_depth(width: int, height: int,
distance_maps: List[np.ndarray], pred_masks: List[np.ndarray],
extrinsics: List[np.ndarray], intrinsics: List[np.ndarray],
on_view: Optional[Callable[[], None]] = None,
on_solve_start: Optional[Callable[[int, int], None]] = None,
on_solve_end: Optional[Callable[[int, int], None]] = None,
) -> Tuple[np.ndarray, np.ndarray]:
"""Stitch per-view distance maps into a single equirect distance map.
Recursive multi-scale solve: solves at half resolution first and uses that as the lsmr init
for the full-resolution solve. Optional callbacks fire per view processed and around each
lsmr solve so callers can drive a progress bar.
"""
if max(width, height) > 256:
coarse_depth, _ = merge_panorama_depth(width // 2, height // 2,
distance_maps, pred_masks, extrinsics, intrinsics,
on_view=on_view,
on_solve_start=on_solve_start,
on_solve_end=on_solve_end)
t = torch.from_numpy(coarse_depth).unsqueeze(0).unsqueeze(0)
t = F.interpolate(t, size=(height, width), mode="bilinear", align_corners=False)
depth_init = t.squeeze().numpy().astype(np.float32)
else:
depth_init = None
spherical_directions = spherical_uv_to_directions(_uv_grid(height, width))
pano_log_grad_maps, pano_grad_masks = [], []
pano_log_lap_maps, pano_lap_masks = [], []
pano_pred_masks: List[np.ndarray] = []
for i in range(len(distance_maps)):
proj_uv, proj_depth = _project_cv(spherical_directions, extrinsics[i], intrinsics[i])
proj_valid = (proj_depth > 0) & (proj_uv > 0).all(axis=-1) & (proj_uv < 1).all(axis=-1)
Hd, Wd = distance_maps[i].shape[:2]
proj_pixels = np.clip(proj_uv, 0, 1) * np.array([Wd - 1, Hd - 1], dtype=np.float32)
log_dist = np.log(np.clip(distance_maps[i], 1e-6, None))
sampled = _scipy_remap_bilinear(log_dist, proj_pixels, mode="bilinear")
pano_log = np.where(proj_valid, sampled, 0.0).astype(np.float32)
sampled_mask = _scipy_remap_bilinear(pred_masks[i].astype(np.uint8), proj_pixels, mode="nearest")
pano_pred = proj_valid & (sampled_mask > 0)
# Equirect wraps horizontally but not vertically: wrap pad along x, edge pad along y.
padded = np.pad(pano_log, ((0, 0), (0, 1)), mode="wrap")
gx, gy = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :]
padded_m = np.pad(pano_pred, ((0, 0), (0, 1)), mode="wrap")
mx, my = padded_m[:, :-1] & padded_m[:, 1:], padded_m[:-1, :] & padded_m[1:, :]
pano_log_grad_maps.append((gx, gy))
pano_grad_masks.append((mx, my))
padded = np.pad(pano_log, ((1, 1), (0, 0)), mode="edge")
padded = np.pad(padded, ((0, 0), (1, 1)), mode="wrap")
lap_kernel = np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32)
lap = convolve(padded, lap_kernel)[1:-1, 1:-1]
padded_m = np.pad(pano_pred, ((1, 1), (0, 0)), mode="edge")
padded_m = np.pad(padded_m, ((0, 0), (1, 1)), mode="wrap")
m_kernel = np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8)
lap_mask = convolve(padded_m.astype(np.uint8), m_kernel)[1:-1, 1:-1] == 5
pano_log_lap_maps.append(lap)
pano_lap_masks.append(lap_mask)
pano_pred_masks.append(pano_pred)
if on_view is not None:
on_view()
gx = np.stack([m[0] for m in pano_log_grad_maps], axis=0)
gy = np.stack([m[1] for m in pano_log_grad_maps], axis=0)
mx = np.stack([m[0] for m in pano_grad_masks], axis=0)
my = np.stack([m[1] for m in pano_grad_masks], axis=0)
gx_avg = (gx * mx).sum(axis=0) / mx.sum(axis=0).clip(1e-3)
gy_avg = (gy * my).sum(axis=0) / my.sum(axis=0).clip(1e-3)
laps = np.stack(pano_log_lap_maps, axis=0)
lap_masks = np.stack(pano_lap_masks, axis=0)
lap_avg = (laps * lap_masks).sum(axis=0) / lap_masks.sum(axis=0).clip(1e-3)
grad_x_mask = mx.any(axis=0).reshape(-1)
grad_y_mask = my.any(axis=0).reshape(-1)
grad_mask = np.concatenate([grad_x_mask, grad_y_mask])
lap_mask_flat = lap_masks.any(axis=0).reshape(-1)
A = vstack([
_grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask],
_poisson_equation(width, height, wrap_x=True, wrap_y=False)[lap_mask_flat],
])
b = np.concatenate([
gx_avg.reshape(-1)[grad_x_mask],
gy_avg.reshape(-1)[grad_y_mask],
lap_avg.reshape(-1)[lap_mask_flat],
])
x0 = np.log(np.clip(depth_init, 1e-6, None)).reshape(-1) if depth_init is not None else None
if on_solve_start is not None:
on_solve_start(width, height)
x, *_ = lsmr(A, b, atol=1e-5, btol=1e-5, x0=x0, show=False)
if on_solve_end is not None:
on_solve_end(width, height)
pano_depth = np.exp(x).reshape(height, width).astype(np.float32)
pano_mask = np.any(pano_pred_masks, axis=0)
return pano_depth, pano_mask

View File

@ -97,12 +97,14 @@ def load_lora(lora, to_load, log_missing=True):
def model_lora_keys_clip(model, key_map={}):
sdk = model.state_dict().keys()
prefix_set = set()
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
if tp > 0 and not k.startswith("clip_"):
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
prefix_set.add(k.split('.')[0])
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False
@ -163,6 +165,13 @@ def model_lora_keys_clip(model, key_map={}):
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
key_map[lora_key] = k
if len(prefix_set) == 1:
full_prefix = "{}.transformer.model.".format(next(iter(prefix_set))) # kohya anima and maybe other single TE models that use a single llama arch based te
for k in sdk:
if k.endswith(".weight"):
if k.startswith(full_prefix):
l_key = k[len(full_prefix):-len(".weight")]
key_map["lora_te_{}".format(l_key.replace(".", "_"))] = k
k = "clip_g.transformer.text_projection.weight"
if k in sdk:

View File

@ -58,6 +58,8 @@ import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.ldm.hidream_o1.model
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
import comfy.model_management
import comfy.patcher_extension
@ -1674,6 +1676,39 @@ class HiDream(BaseModel):
out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
return out
class HiDreamO1(BaseModel):
"""HiDream-O1-Image: pixel-space DiT (no VAE). Refs from HiDreamO1ReferenceImages and tokens from the stub TE flow through
extra_conds; the heavy preprocessing lives in comfy.ldm.hidream_o1.conditioning."""
PATCH_SIZE = 32
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream_o1.model.HiDreamO1Transformer)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
text_input_ids = kwargs.get("text_input_ids", None)
noise = kwargs.get("noise", None)
if text_input_ids is None or noise is None:
return out
# handle area conds
area = kwargs.get("area", None)
if area is not None:
crop_h = min(noise.shape[-2] - area[2], area[0])
crop_w = min(noise.shape[-1] - area[3], area[1])
noise = torch.empty((noise.shape[0], 3, crop_h, crop_w), dtype=noise.dtype, device=noise.device)
conds = build_extra_conds(
text_input_ids, noise,
ref_images=kwargs.get("reference_latents", None),
target_patch_size=self.PATCH_SIZE,
)
for k, v in conds.items():
# ar_len is a Python int (precomputed to avoid a GPU sync in forward).
cls = comfy.conds.CONDConstant if k == "ar_len" else comfy.conds.CONDRegular
out[k] = cls(v)
return out
class Chroma(Flux):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.chroma.model.Chroma):
super().__init__(model_config, model_type, device=device, unet_model=unet_model)

View File

@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
return {"image_model": "hidream_o1"}
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"

View File

@ -242,6 +242,37 @@ class LazyCastingParam(torch.nn.Parameter):
return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
class LazyCastingQuantizedParam:
def __init__(self, model, key):
self.model = model
self.key = key
self.cpu_state_dict = None
def state_dict_tensor(self, state_dict_key):
if self.cpu_state_dict is None:
weight = self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True)
self.cpu_state_dict = {k: v.to("cpu") for k, v in weight.state_dict(self.key).items()}
return self.cpu_state_dict[state_dict_key]
class LazyCastingParamPiece(torch.nn.Parameter):
def __new__(cls, caster, state_dict_key, tensor):
return super().__new__(cls, tensor)
def __init__(self, caster, state_dict_key, tensor):
self.caster = caster
self.state_dict_key = state_dict_key
@property
def device(self):
return CustomTorchDevice
def to(self, *args, **kwargs):
caster = self.caster
del self.caster
return caster.state_dict_tensor(self.state_dict_key)
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
@ -1463,20 +1494,37 @@ class ModelPatcher:
self.clear_cached_hook_weights()
def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
unet_state_dict = self.model.diffusion_model.state_dict()
for k, v in unet_state_dict.items():
original_state_dict = self.model.diffusion_model.state_dict()
unet_state_dict = {}
keys = list(original_state_dict)
while len(keys) > 0:
k = keys.pop(0)
v = original_state_dict[k]
op_keys = k.rsplit('.', 1)
if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
unet_state_dict[k] = v
continue
try:
op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
except:
unet_state_dict[k] = v
continue
if not op or not hasattr(op, "comfy_cast_weights") or \
(hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
unet_state_dict[k] = v
continue
key = "diffusion_model." + k
unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
weight = comfy.utils.get_attr(self.model, key)
if isinstance(weight, QuantizedTensor) and k in original_state_dict:
qt_state_dict = weight.state_dict(k)
caster = LazyCastingQuantizedParam(self, key)
for group_key in (x for x in qt_state_dict if x in original_state_dict):
if group_key in keys:
keys.remove(group_key)
unet_state_dict.pop(group_key, "")
unet_state_dict[group_key] = LazyCastingParamPiece(caster, "diffusion_model." + group_key, original_state_dict[group_key])
continue
unet_state_dict[k] = LazyCastingParam(self, key, weight)
return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
def __del__(self):

View File

@ -93,7 +93,8 @@ class CONST:
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = reshape_sigma(sigma, noise.ndim)
return sigma * noise + (1.0 - sigma) * latent_image
s = getattr(self, "noise_scale", 1.0)
return sigma * (s * noise) + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
sigma = reshape_sigma(sigma, latent.ndim)
@ -288,7 +289,11 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.0), multiplier=sampling_settings.get("multiplier", 1000))
self.set_noise_scale(sampling_settings.get("noise_scale", 1.0))
self.set_parameters(
shift=sampling_settings.get("shift", 1.0),
multiplier=sampling_settings.get("multiplier", 1000),
)
def set_parameters(self, shift=1.0, timesteps=1000, multiplier=1000):
self.shift = shift
@ -296,6 +301,9 @@ class ModelSamplingDiscreteFlow(torch.nn.Module):
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps) * multiplier)
self.register_buffer('sigmas', ts)
def set_noise_scale(self, noise_scale):
self.noise_scale = float(noise_scale)
@property
def sigma_min(self):
return self.sigmas[0]

View File

@ -1285,7 +1285,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
@ -1375,6 +1376,7 @@ def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")
logging.info("Native ops: {} {}".format(", ".join(QUANT_ALGOS.keys() - disabled), ", emulated ops: {}".format(", ".join(disabled)) if len(disabled) > 0 else ""))
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
if (

View File

@ -79,7 +79,7 @@ import comfy.latent_formats
import comfy.ldm.flux.redux
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
key_map = {}
if model is not None:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
@ -91,6 +91,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if model is not None:
new_modelpatcher = model.clone()
k = new_modelpatcher.add_patches(loaded, strength_model)
if lora_metadata:
new_modelpatcher.set_attachments("lora_metadata", lora_metadata)
else:
k = ()
new_modelpatcher = None
@ -98,6 +100,8 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if clip is not None:
new_clip = clip.clone()
k1 = new_clip.add_patches(loaded, strength_clip)
if lora_metadata:
new_clip.patcher.set_attachments("lora_metadata", lora_metadata)
else:
k1 = ()
new_clip = None
@ -239,7 +243,8 @@ class CLIP:
model_management.archive_model_dtypes(self.cond_stage_model)
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
te_disable_dynamic = disable_dynamic or getattr(self.cond_stage_model, "disable_offload", False)
ModelPatcher = comfy.model_patcher.ModelPatcher if te_disable_dynamic else comfy.model_patcher.CoreModelPatcher
self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
@ -776,6 +781,7 @@ class VAE:
self.latent_channels = 3
self.latent_dim = 2
self.output_channels = 3
self.disable_offload = True
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
sample_rate = 16000
if sample_rate == 16000:

View File

@ -28,6 +28,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1
from . import supported_models_base
from . import latent_formats
@ -1431,6 +1432,50 @@ class HiDream(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None # TODO
class HiDreamO1(supported_models_base.BASE):
unet_config = {
"image_model": "hidream_o1",
}
sampling_settings = {
"shift": 3.0,
"noise_scale": 8.0,
}
latent_format = latent_formats.HiDreamO1Pixel
memory_usage_factor = 0.033
# fp16 not supported: LM MLP down_proj activations fp16 overflow, causing NaNs
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
optimizations = {"fp8": False}
def get_model(self, state_dict, prefix="", device=None):
return model_base.HiDreamO1(self, device=device)
def process_unet_state_dict(self, state_dict):
# Drop unused Qwen3-VL deepstack merger weights; upstream discards them at inference.
for key in list(state_dict.keys()):
if "visual.deepstack_merger_list" in key:
del state_dict[key]
return state_dict
def process_vae_state_dict(self, state_dict):
# Pixel-space model: inject sentinel so VAE construction picks PixelspaceConversionVAE.
return {"pixel_space_vae": torch.tensor(1.0)}
def process_clip_state_dict(self, state_dict):
# Tokenizer-only TE: inject sentinel so load_state_dict_guess_config triggers CLIP init.
return {"_hidream_o1_te_sentinel": torch.zeros(1)}
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(
comfy.text_encoders.hidream_o1.HiDreamO1Tokenizer,
comfy.text_encoders.hidream_o1.HiDreamO1TE,
)
class Chroma(supported_models_base.BASE):
unet_config = {
"image_model": "chroma",
@ -2018,6 +2063,7 @@ models = [
Hunyuan3Dv2,
Hunyuan3Dv2_1,
HiDream,
HiDreamO1,
Chroma,
ChromaRadiance,
ACEStep,

View File

@ -0,0 +1,119 @@
"""HiDream-O1-Image tokenizer-only text encoder.
The real Qwen3-VL backbone runs inside diffusion_model.* every step, so this
module just tokenizes the prompt into text_input_ids and emits them as
conditioning. Position ids / token_types / vinput_mask depend on target H/W
and are built later in model_base.HiDreamO1.extra_conds.
"""
import os
import torch
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
# Qwen3-VL special tokens
IM_START_ID = 151644
IM_END_ID = 151645
ASSISTANT_ID = 77091
USER_ID = 872
NEWLINE_ID = 198
VISION_START_ID = 151652
VISION_END_ID = 151653
IMAGE_TOKEN_ID = 151655
VIDEO_TOKEN_ID = 151656
# HiDream-O1-specific tokens
BOI_TOKEN_ID = 151669
BOR_TOKEN_ID = 151670
EOR_TOKEN_ID = 151671
BOT_TOKEN_ID = 151672
TMS_TOKEN_ID = 151673
class HiDreamO1QwenTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer"
)
super().__init__(
tokenizer_path,
pad_with_end=False,
embedding_size=4096,
embedding_key="hidream_o1",
tokenizer_class=Qwen2Tokenizer,
has_start_token=False,
has_end_token=False,
pad_to_max_length=False,
max_length=99999999,
min_length=1,
pad_token=151643,
tokenizer_data=tokenizer_data,
)
class HiDreamO1Tokenizer(sd1_clip.SD1Tokenizer):
"""Wraps prompt in the upstream chat template ending with boi/tms markers.
Image tokens get spliced in at sample time once target H/W is known.
"""
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
tokenizer_data=tokenizer_data,
name="hidream_o1",
tokenizer=HiDreamO1QwenTokenizer,
)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
text_tokens_dict = super().tokenize_with_weights(
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
)
text_tuples = text_tokens_dict["hidream_o1"][0]
text_tuples = [t for t in text_tuples if int(t[0]) != 151643] # strip pad
# <|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant\n<|boi|><|tms|>
def tok(tid):
return (tid, 1.0) if not return_word_ids else (tid, 1.0, 0)
prefix = [tok(IM_START_ID), tok(USER_ID), tok(NEWLINE_ID)]
suffix = [
tok(IM_END_ID), tok(NEWLINE_ID),
tok(IM_START_ID), tok(ASSISTANT_ID), tok(NEWLINE_ID),
tok(BOI_TOKEN_ID), tok(TMS_TOKEN_ID),
]
full = prefix + list(text_tuples) + suffix
return {"hidream_o1": [full]}
class HiDreamO1TE(torch.nn.Module):
"""Passthrough TE: emits int token ids; the Qwen3-VL backbone in diffusion_model does the actual encoding."""
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = {torch.float32}
self.disable_offload = True # skips dynamic VRAM management for this zero-parameter module
self.device = torch.device("cpu") if device is None else torch.device(device)
def encode_token_weights(self, token_weight_pairs):
tok_pairs = token_weight_pairs["hidream_o1"][0]
ids = [int(t[0]) for t in tok_pairs]
input_ids = torch.tensor([ids], dtype=torch.long)
# Surrogate keeps the cross_attn slot non-empty for CONDITIONING
# plumbing; the model reads text_input_ids out of `extra` instead.
cross_attn = input_ids.unsqueeze(-1).to(torch.float32)
extra = {"text_input_ids": input_ids}
return cross_attn, None, extra
def load_sd(self, sd):
return []
def get_sd(self):
return {}
def reset_clip_options(self):
pass
def set_clip_options(self, options):
pass

View File

@ -397,7 +397,7 @@ class RMSNorm(nn.Module):
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None, interleaved_mrope=False):
if not isinstance(theta, list):
theta = [theta]
@ -415,16 +415,27 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
if rope_dims is not None and position_ids.shape[0] > 1 and interleaved_mrope:
# Qwen3-VL interleaved MRoPE: T-freqs by default, H/W replace every 3rd dim.
freqs_inter = freqs[0].clone()
for axis_idx, offset in ((1, 1), (2, 2)):
length = rope_dims[axis_idx] * 3
idx = slice(offset, length, 3)
freqs_inter[..., idx] = freqs[axis_idx, ..., idx]
emb = torch.cat((freqs_inter, freqs_inter), dim=-1)
cos = emb.cos().unsqueeze(0)
sin = emb.sin().unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
sin_split = sin.shape[-1] // 2
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
@ -689,6 +700,7 @@ class Llama2_(nn.Module):
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):

View File

@ -451,9 +451,8 @@ class Qwen35VisionPatchEmbed(nn.Module):
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
def forward(self, x):
target_dtype = self.proj.weight.dtype
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
return self.proj(x).view(-1, self.embed_dim)
class Qwen35VisionMLP(nn.Module):
@ -651,7 +650,7 @@ class Qwen35VisionModel(nn.Module):
x = self.patch_embed(x)
pos_embeds = self.fast_pos_embed_interpolate(grid_thw).to(x.device)
x = x + pos_embeds
rotary_pos_emb = self.rot_pos_emb(grid_thw)
rotary_pos_emb = self.rot_pos_emb(grid_thw).to(x.device)
seq_len = x.shape[0]
x = x.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
@ -761,7 +760,7 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
image = kwargs.get("image", None)
if image is not None and len(images) == 0:
images = [image]
images = [image[i:i + 1] for i in range(image.shape[0])]
skip_template = False
if text.startswith('<|im_start|>'):
@ -772,13 +771,16 @@ class Qwen35ImageTokenizer(sd1_clip.SD1Tokenizer):
if skip_template:
llama_text = text
else:
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
if llama_template is not None:
template = llama_template
elif len(images) == 0:
template = self.llama_template
else:
llama_text = llama_template.format(text)
template = self.llama_template_images
if len(images) > 1:
vision_block = "<|vision_start|><|image_pad|><|vision_end|>"
template = template.replace(vision_block, vision_block * len(images), 1)
llama_text = template.format(text)
if not thinking:
llama_text += "<think>\n</think>\n"

View File

@ -1164,12 +1164,18 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
o = out
o_d = out_div
ps_view = ps
mask_view = mask
for d in range(dims):
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
l = min(ps_view.shape[d + 2], o.shape[d + 2] - upscaled[d])
o = o.narrow(d + 2, upscaled[d], l)
o_d = o_d.narrow(d + 2, upscaled[d], l)
if l < ps_view.shape[d + 2]:
ps_view = ps_view.narrow(d + 2, 0, l)
mask_view = mask_view.narrow(d + 2, 0, l)
o.add_(ps * mask)
o_d.add_(mask)
o.add_(ps_view * mask_view)
o_d.add_(mask_view)
if pbar is not None:
pbar.update(1)
@ -1196,7 +1202,7 @@ def model_trange(*args, **kwargs):
pbar.i1_time = time.time()
pbar.set_postfix_str(" Model Initialization complete! ")
elif pbar._i == 2:
#bring forward the effective start time based the the diff between first and second iteration
#bring forward the effective start time based the diff between first and second iteration
#to attempt to remove load overhead from the final step rate estimate.
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
pbar.set_postfix_str("")

View File

@ -12,9 +12,24 @@ class VOXEL:
class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
self.vertices = vertices
self.faces = faces
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor,
uvs: torch.Tensor | None = None,
vertex_colors: torch.Tensor | None = None,
texture: torch.Tensor | None = None,
vertex_counts: torch.Tensor | None = None,
face_counts: torch.Tensor | None = None):
assert (vertex_counts is None) == (face_counts is None), \
"vertex_counts and face_counts must be provided together (both or neither)"
self.vertices = vertices # vertices: (B, N, 3)
self.faces = faces # faces: (B, M, 3)
self.uvs = uvs # uvs: (B, N, 2)
self.vertex_colors = vertex_colors # vertex_colors: (B, N, 3 or 4)
self.texture = texture # texture: (B, H, W, 3)
# When vertices/faces are zero-padded to a common N/M across the batch (variable-size mesh batch),
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
self.vertex_counts = vertex_counts
self.face_counts = face_counts
class File3D:

View File

@ -0,0 +1,75 @@
from enum import Enum
from typing import Literal
from pydantic import BaseModel, Field
class AnthropicRole(str, Enum):
user = "user"
assistant = "assistant"
class AnthropicTextContent(BaseModel):
type: Literal["text"] = "text"
text: str = Field(...)
class AnthropicImageSourceBase64(BaseModel):
type: Literal["base64"] = "base64"
media_type: str = Field(..., description="MIME type of the image, e.g. image/png, image/jpeg")
data: str = Field(..., description="Base64-encoded image data")
class AnthropicImageSourceUrl(BaseModel):
type: Literal["url"] = "url"
url: str = Field(...)
class AnthropicImageContent(BaseModel):
type: Literal["image"] = "image"
source: AnthropicImageSourceBase64 | AnthropicImageSourceUrl = Field(...)
class AnthropicMessage(BaseModel):
role: AnthropicRole = Field(...)
content: list[AnthropicTextContent | AnthropicImageContent] = Field(...)
class AnthropicMessagesRequest(BaseModel):
model: str = Field(...)
messages: list[AnthropicMessage] = Field(...)
max_tokens: int = Field(..., ge=1)
system: str | None = Field(None, description="Top-level system prompt")
temperature: float | None = Field(None, ge=0.0, le=1.0)
top_p: float | None = Field(None, ge=0.0, le=1.0)
top_k: int | None = Field(None, ge=0)
stop_sequences: list[str] | None = Field(None)
class AnthropicResponseTextBlock(BaseModel):
type: Literal["text"] = "text"
text: str = Field(...)
class AnthropicCacheCreationUsage(BaseModel):
ephemeral_5m_input_tokens: int | None = Field(None)
ephemeral_1h_input_tokens: int | None = Field(None)
class AnthropicMessagesUsage(BaseModel):
input_tokens: int | None = Field(None)
output_tokens: int | None = Field(None)
cache_creation_input_tokens: int | None = Field(None)
cache_read_input_tokens: int | None = Field(None)
cache_creation: AnthropicCacheCreationUsage | None = Field(None)
class AnthropicMessagesResponse(BaseModel):
id: str | None = Field(None)
type: str | None = Field(None)
role: str | None = Field(None)
model: str | None = Field(None)
content: list[AnthropicResponseTextBlock] | None = Field(None)
stop_reason: str | None = Field(None)
stop_sequence: str | None = Field(None)
usage: AnthropicMessagesUsage | None = Field(None)

View File

@ -23,7 +23,7 @@ class BriaEditImageRequest(BaseModel):
None,
description="Mask image (black and white). Black areas will be preserved, white areas will be edited. "
"If omitted, the edit applies to the entire image. "
"The input image and the the input mask must be of the same size.",
"The input image and the input mask must be of the same size.",
)
negative_prompt: str | None = Field(None)
guidance_scale: float = Field(...)

View File

@ -198,6 +198,62 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
("Custom", None, None),
]
_PRESETS_SEEDREAM_1K = [
("(1K) 1024x1024 (1:1)", 1024, 1024),
("(1K) 864x1152 (3:4)", 864, 1152),
("(1K) 1152x864 (4:3)", 1152, 864),
("(1K) 1312x736 (16:9)", 1312, 736),
("(1K) 736x1312 (9:16)", 736, 1312),
("(1K) 832x1248 (2:3)", 832, 1248),
("(1K) 1248x832 (3:2)", 1248, 832),
("(1K) 1568x672 (21:9)", 1568, 672),
]
_PRESETS_SEEDREAM_2K = [
("(2K) 2048x2048 (1:1)", 2048, 2048),
("(2K) 1728x2304 (3:4)", 1728, 2304),
("(2K) 2304x1728 (4:3)", 2304, 1728),
("(2K) 2848x1600 (16:9)", 2848, 1600),
("(2K) 1600x2848 (9:16)", 1600, 2848),
("(2K) 1664x2496 (2:3)", 1664, 2496),
("(2K) 2496x1664 (3:2)", 2496, 1664),
("(2K) 3136x1344 (21:9)", 3136, 1344),
]
_PRESETS_SEEDREAM_3K = [
("(3K) 3072x3072 (1:1)", 3072, 3072),
("(3K) 2592x3456 (3:4)", 2592, 3456),
("(3K) 3456x2592 (4:3)", 3456, 2592),
("(3K) 4096x2304 (16:9)", 4096, 2304),
("(3K) 2304x4096 (9:16)", 2304, 4096),
("(3K) 2496x3744 (2:3)", 2496, 3744),
("(3K) 3744x2496 (3:2)", 3744, 2496),
("(3K) 4704x2016 (21:9)", 4704, 2016),
]
_PRESETS_SEEDREAM_4K = [
("(4K) 4096x4096 (1:1)", 4096, 4096),
("(4K) 3520x4704 (3:4)", 3520, 4704),
("(4K) 4704x3520 (4:3)", 4704, 3520),
("(4K) 5504x3040 (16:9)", 5504, 3040),
("(4K) 3040x5504 (9:16)", 3040, 5504),
("(4K) 3328x4992 (2:3)", 3328, 4992),
("(4K) 4992x3328 (3:2)", 4992, 3328),
("(4K) 6240x2656 (21:9)", 6240, 2656),
]
_CUSTOM_PRESET = [("Custom", None, None)]
RECOMMENDED_PRESETS_SEEDREAM_5_LITE = (
_PRESETS_SEEDREAM_2K + _PRESETS_SEEDREAM_3K + _PRESETS_SEEDREAM_4K + _CUSTOM_PRESET
)
RECOMMENDED_PRESETS_SEEDREAM_4_5 = (
_PRESETS_SEEDREAM_2K + _PRESETS_SEEDREAM_4K + _CUSTOM_PRESET
)
RECOMMENDED_PRESETS_SEEDREAM_4_0 = (
_PRESETS_SEEDREAM_1K + _PRESETS_SEEDREAM_2K + _PRESETS_SEEDREAM_4K + _CUSTOM_PRESET
)
# Seedance 2.0 reference video pixel count limits per model and output resolution.
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
"dreamina-seedance-2-0-260128": {

View File

@ -0,0 +1,245 @@
"""API Nodes for Anthropic Claude (Messages API). See: https://docs.anthropic.com/en/api/messages"""
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.anthropic import (
AnthropicImageContent,
AnthropicImageSourceUrl,
AnthropicMessage,
AnthropicMessagesRequest,
AnthropicMessagesResponse,
AnthropicRole,
AnthropicTextContent,
)
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op,
upload_images_to_comfyapi,
validate_string,
)
ANTHROPIC_MESSAGES_ENDPOINT = "/proxy/anthropic/v1/messages"
ANTHROPIC_IMAGE_MAX_PIXELS = 1568 * 1568
CLAUDE_MAX_IMAGES = 20
CLAUDE_MODELS: dict[str, str] = {
"Opus 4.7": "claude-opus-4-7",
"Opus 4.6": "claude-opus-4-6",
"Sonnet 4.6": "claude-sonnet-4-6",
"Sonnet 4.5": "claude-sonnet-4-5-20250929",
"Haiku 4.5": "claude-haiku-4-5-20251001",
}
def _claude_model_inputs():
return [
IO.Int.Input(
"max_tokens",
default=16000,
min=32,
max=32000,
tooltip="Maximum number of tokens to generate before stopping.",
advanced=True,
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.",
advanced=True,
),
]
def _model_price_per_million(model: str) -> tuple[float, float] | None:
"""Return (input_per_1M, output_per_1M) USD for a Claude model, or None if unknown."""
if "opus-4-7" in model or "opus-4-6" in model or "opus-4-5" in model:
return 5.0, 25.0
if "sonnet-4" in model:
return 3.0, 15.0
if "haiku-4-5" in model:
return 1.0, 5.0
return None
def calculate_tokens_price(response: AnthropicMessagesResponse) -> float | None:
"""Compute approximate USD price from response usage. Server-side billing is authoritative."""
if not response.usage or not response.model:
return None
rates = _model_price_per_million(response.model)
if rates is None:
return None
input_rate, output_rate = rates
input_tokens = response.usage.input_tokens or 0
output_tokens = response.usage.output_tokens or 0
cache_read = response.usage.cache_read_input_tokens or 0
cache_5m = 0
cache_1h = 0
if response.usage.cache_creation:
cache_5m = response.usage.cache_creation.ephemeral_5m_input_tokens or 0
cache_1h = response.usage.cache_creation.ephemeral_1h_input_tokens or 0
total = (
input_tokens * input_rate
+ output_tokens * output_rate
+ cache_read * input_rate * 0.1
+ cache_5m * input_rate * 1.25
+ cache_1h * input_rate * 2.0
)
return total / 1_000_000.0
def _get_text_from_response(response: AnthropicMessagesResponse) -> str:
if not response.content:
return ""
return "\n".join(block.text for block in response.content if block.text)
async def _build_image_content_blocks(
cls: type[IO.ComfyNode],
image_tensors: list[Input.Image],
) -> list[AnthropicImageContent]:
urls = await upload_images_to_comfyapi(
cls,
image_tensors,
max_images=CLAUDE_MAX_IMAGES,
total_pixels=ANTHROPIC_IMAGE_MAX_PIXELS,
wait_label="Uploading reference images",
)
return [AnthropicImageContent(source=AnthropicImageSourceUrl(url=url)) for url in urls]
class ClaudeNode(IO.ComfyNode):
"""Generate text responses from an Anthropic Claude model."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ClaudeNode",
display_name="Anthropic Claude",
category="api node/text/Anthropic",
essentials_category="Text Generation",
description="Generate text responses with Anthropic's Claude models. "
"Provide a text prompt and optionally one or more images for multimodal context.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text input to the model.",
),
IO.DynamicCombo.Input(
"model",
options=[IO.DynamicCombo.Option(label, _claude_model_inputs()) for label in CLAUDE_MODELS],
tooltip="The Claude model used to generate the response.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, CLAUDE_MAX_IMAGES + 1)],
min=0,
),
tooltip=f"Optional image(s) to use as context for the model. Up to {CLAUDE_MAX_IMAGES} images.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
advanced=True,
tooltip="Foundational instructions that dictate the model's behavior.",
),
],
outputs=[IO.String.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m, "opus") ? {
"type": "list_usd",
"usd": [0.005, 0.025],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "sonnet") ? {
"type": "list_usd",
"usd": [0.003, 0.015],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "haiku") ? {
"type": "list_usd",
"usd": [0.001, 0.005],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: {"type":"text", "text":"Token-based"}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
images: dict | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_label = model["model"]
max_tokens = model["max_tokens"]
temperature = None if model_label == "Opus 4.7" else model["temperature"]
image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None]
if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES:
raise ValueError(f"Up to {CLAUDE_MAX_IMAGES} images are supported per request.")
content: list[AnthropicTextContent | AnthropicImageContent] = []
if image_tensors:
content.extend(await _build_image_content_blocks(cls, image_tensors))
content.append(AnthropicTextContent(text=prompt))
response = await sync_op(
cls,
ApiEndpoint(path=ANTHROPIC_MESSAGES_ENDPOINT, method="POST"),
response_model=AnthropicMessagesResponse,
data=AnthropicMessagesRequest(
model=CLAUDE_MODELS[model_label],
max_tokens=max_tokens,
messages=[AnthropicMessage(role=AnthropicRole.user, content=content)],
system=system_prompt or None,
temperature=temperature,
),
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(_get_text_from_response(response) or "Empty response from Claude model.")
class AnthropicExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ClaudeNode]
async def comfy_entrypoint() -> AnthropicExtension:
return AnthropicExtension()

View File

@ -596,6 +596,7 @@ class Flux2ProImageNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]),
expr=cls.PRICE_BADGE_EXPR,
),
is_deprecated=True,
)
@classmethod
@ -674,6 +675,175 @@ class Flux2MaxImageNode(Flux2ProImageNode):
"""
_FLUX2_MODEL_ENDPOINTS = {
"Flux.2 [pro]": "/proxy/bfl/flux-2-pro/generate",
"Flux.2 [max]": "/proxy/bfl/flux-2-max/generate",
}
def _flux2_model_inputs():
return [
IO.Int.Input(
"width",
default=1024,
min=256,
max=2048,
step=32,
),
IO.Int.Input(
"height",
default=768,
min=256,
max=2048,
step=32,
),
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 9)],
min=0,
),
tooltip="Optional reference image(s) for image-to-image generation. Up to 8 images.",
),
]
class Flux2ImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Flux2ImageNode",
display_name="Flux.2 Image",
category="api node/image/BFL",
description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation or edit",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Flux.2 [pro]", _flux2_model_inputs()),
IO.DynamicCombo.Option("Flux.2 [max]", _flux2_model_inputs()),
],
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model", "model.width", "model.height"],
input_groups=["model.images"],
),
expr="""
(
$isMax := widgets.model = "flux.2 [max]";
$MP := 1024 * 1024;
$w := $lookup(widgets, "model.width");
$h := $lookup(widgets, "model.height");
$outMP := $max([1, $floor((($w * $h) + $MP - 1) / $MP)]);
$outputCost := $isMax
? (0.07 + 0.03 * ($outMP - 1))
: (0.03 + 0.015 * ($outMP - 1));
$refMin := $isMax ? 0.03 : 0.015;
$refMax := $isMax ? 0.24 : 0.12;
$hasRefs := $lookup(inputGroups, "model.images") > 0;
$hasRefs
? {
"type": "range_usd",
"min_usd": $outputCost + $refMin,
"max_usd": $outputCost + $refMax,
"format": { "approximate": true }
}
: {"type": "usd", "usd": $outputCost}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
model_choice = model["model"]
endpoint = _FLUX2_MODEL_ENDPOINTS[model_choice]
width = model["width"]
height = model["height"]
images_dict = model.get("images") or {}
image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
n_images = sum(get_number_of_images(t) for t in image_tensors)
if n_images > 8:
raise ValueError("The current maximum number of supported images is 8.")
flat_tensors: list[torch.Tensor] = []
for tensor in image_tensors:
if len(tensor.shape) == 4:
flat_tensors.extend(tensor[i] for i in range(tensor.shape[0]))
else:
flat_tensors.append(tensor)
reference_images: dict[str, str] = {}
for idx, tensor in enumerate(flat_tensors):
key_name = f"input_image_{idx + 1}" if idx else "input_image"
reference_images[key_name] = tensor_to_base64_string(tensor, total_pixels=2048 * 2048)
initial_response = await sync_op(
cls,
ApiEndpoint(path=endpoint, method="POST"),
response_model=BFLFluxProGenerateResponse,
data=Flux2ProGenerateRequest(
prompt=prompt,
width=width,
height=height,
seed=seed,
**reference_images,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class BFLExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -685,6 +855,7 @@ class BFLExtension(ComfyExtension):
FluxProFillNode,
Flux2ProImageNode,
Flux2MaxImageNode,
Flux2ImageNode,
]

View File

@ -10,6 +10,9 @@ from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
RECOMMENDED_PRESETS_SEEDREAM_4_0,
RECOMMENDED_PRESETS_SEEDREAM_4_5,
RECOMMENDED_PRESETS_SEEDREAM_5_LITE,
SEEDANCE2_PRICE_PER_1K_TOKENS,
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS,
VIDEO_TASKS_EXECUTION_TIME,
@ -68,6 +71,12 @@ SEEDREAM_MODELS = {
"seedream-4-0-250828": "seedream-4-0-250828",
}
SEEDREAM_PRESETS = {
"seedream-5-0-260128": RECOMMENDED_PRESETS_SEEDREAM_5_LITE,
"seedream-4-5-251128": RECOMMENDED_PRESETS_SEEDREAM_4_5,
"seedream-4-0-250828": RECOMMENDED_PRESETS_SEEDREAM_4_0,
}
# Long-running tasks endpoints(e.g., video)
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
@ -562,6 +571,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
)
""",
),
is_deprecated=True,
)
@classmethod
@ -651,6 +661,226 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
return IO.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls]))
def _seedream_model_inputs(*, max_ref_images: int, presets: list):
return [
IO.Combo.Input(
"size_preset",
options=[label for label, _, _ in presets],
tooltip="Pick a recommended size. Select Custom to use the width and height below.",
),
IO.Int.Input(
"width",
default=2048,
min=1024,
max=6240,
step=2,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
),
IO.Int.Input(
"height",
default=2048,
min=1024,
max=4992,
step=2,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
),
IO.Int.Input(
"max_images",
default=1,
min=1,
max=max_ref_images,
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Maximum number of images to generate. With 1, exactly one image is produced. "
"With >1, the model generates between 1 and max_images related images "
"(e.g., story scenes, character variations). "
"Total images (input + generated) cannot exceed 15.",
),
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, max_ref_images + 1)],
min=0,
),
tooltip=f"Optional reference image(s) for image-to-image or multi-reference generation. "
f"Up to {max_ref_images} images.",
),
IO.Boolean.Input(
"fail_on_partial",
default=False,
tooltip="If enabled, abort execution if any requested images are missing or return an error.",
advanced=True,
),
]
class ByteDanceSeedreamNodeV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ByteDanceSeedreamNodeV2",
display_name="ByteDance Seedream 4.5 & 5.0",
category="api node/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for creating or editing an image.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"seedream 5.0 lite",
_seedream_model_inputs(max_ref_images=14, presets=RECOMMENDED_PRESETS_SEEDREAM_5_LITE),
),
IO.DynamicCombo.Option(
"seedream-4-5-251128",
_seedream_model_inputs(max_ref_images=10, presets=RECOMMENDED_PRESETS_SEEDREAM_4_5),
),
IO.DynamicCombo.Option(
"seedream-4-0-250828",
_seedream_model_inputs(max_ref_images=10, presets=RECOMMENDED_PRESETS_SEEDREAM_4_0),
),
],
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip='Whether to add an "AI generated" watermark to the image.',
advanced=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$price := $contains(widgets.model, "5.0 lite") ? 0.035 :
$contains(widgets.model, "4-5") ? 0.04 : 0.03;
{
"type":"usd",
"usd": $price,
"format": { "suffix":" x images/Run", "approximate": true }
}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int = 0,
watermark: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_id = SEEDREAM_MODELS[model["model"]]
presets = SEEDREAM_PRESETS[model_id]
size_preset = model.get("size_preset", presets[0][0])
width = model.get("width", 2048)
height = model.get("height", 2048)
max_images = model.get("max_images", 1)
sequential_image_generation = "disabled" if max_images == 1 else "auto"
images_dict = model.get("images") or {}
fail_on_partial = model.get("fail_on_partial", False)
w = h = None
for label, tw, th in presets:
if label == size_preset:
w, h = tw, th
break
if w is None or h is None:
w, h = width, height
out_num_pixels = w * h
mp_provided = out_num_pixels / 1_000_000.0
if ("seedream-4-5" in model_id or "seedream-5-0" in model_id) and out_num_pixels < 3686400:
raise ValueError(
f"Minimum image resolution for the selected model is 3.68MP, but {mp_provided:.2f}MP provided."
)
if "seedream-4-0" in model_id and out_num_pixels < 921600:
raise ValueError(
f"Minimum image resolution that the selected model can generate is 0.92MP, "
f"but {mp_provided:.2f}MP provided."
)
if out_num_pixels > 16_777_216:
raise ValueError(
f"Maximum image resolution for the selected model is 16.78MP, but {mp_provided:.2f}MP provided."
)
image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
n_input_images = sum(get_number_of_images(t) for t in image_tensors)
max_num_of_images = 14 if model_id == "seedream-5-0-260128" else 10
if n_input_images > max_num_of_images:
raise ValueError(
f"Maximum of {max_num_of_images} reference images are supported, but {n_input_images} received."
)
if sequential_image_generation == "auto" and n_input_images + max_images > 15:
raise ValueError(
"The maximum number of generated images plus the number of reference images cannot exceed 15."
)
reference_images_urls: list[str] = []
if image_tensors:
for tensor in image_tensors:
validate_image_aspect_ratio(tensor, (1, 3), (3, 1))
reference_images_urls = await upload_images_to_comfyapi(
cls,
image_tensors,
max_images=n_input_images,
mime_type="image/png",
wait_label="Uploading reference images",
)
response = await sync_op(
cls,
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
response_model=ImageTaskCreationResponse,
data=Seedream4TaskCreationRequest(
model=model_id,
prompt=prompt,
image=reference_images_urls,
size=f"{w}x{h}",
seed=seed,
sequential_image_generation=sequential_image_generation,
sequential_image_generation_options=Seedream4Options(max_images=max_images),
watermark=watermark,
),
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d]
if fail_on_partial and len(urls) < len(response.data):
raise RuntimeError(f"Only {len(urls)} of {len(response.data)} images were generated before error.")
return IO.NodeOutput(torch.cat([await download_url_to_image_tensor(i) for i in urls]))
class ByteDanceTextToVideoNode(IO.ComfyNode):
@classmethod
@ -2105,6 +2335,7 @@ class ByteDanceExtension(ComfyExtension):
return [
ByteDanceImageNode,
ByteDanceSeedreamNode,
ByteDanceSeedreamNodeV2,
ByteDanceTextToVideoNode,
ByteDanceImageToVideoNode,
ByteDanceFirstLastFrameNode,

View File

@ -162,6 +162,61 @@ class GrokImageNode(IO.ComfyNode):
)
_GROK_IMAGE_EDIT_ASPECT_RATIO_OPTIONS = [
"auto",
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"9:16",
"16:9",
"9:19.5",
"19.5:9",
"9:20",
"20:9",
"1:2",
"2:1",
]
def _grok_image_edit_model_inputs(*, max_ref_images: int, with_aspect_ratio: bool):
inputs = [
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, max_ref_images + 1)],
min=1,
),
tooltip=(
"Reference image to edit."
if max_ref_images == 1
else f"Reference image(s) to edit. Up to {max_ref_images} images."
),
),
IO.Combo.Input("resolution", options=["1K", "2K"]),
IO.Int.Input(
"number_of_images",
default=1,
min=1,
max=10,
step=1,
tooltip="Number of edited images to generate",
display_mode=IO.NumberDisplay.number,
),
]
if with_aspect_ratio:
inputs.append(
IO.Combo.Input(
"aspect_ratio",
options=_GROK_IMAGE_EDIT_ASPECT_RATIO_OPTIONS,
tooltip="Only allowed when multiple images are connected.",
)
)
return inputs
class GrokImageEditNode(IO.ComfyNode):
@classmethod
@ -256,6 +311,7 @@ class GrokImageEditNode(IO.ComfyNode):
)
""",
),
is_deprecated=True,
)
@classmethod
@ -303,6 +359,143 @@ class GrokImageEditNode(IO.ComfyNode):
)
class GrokImageEditNodeV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GrokImageEditNodeV2",
display_name="Grok Image Edit",
category="api node/image/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="The text prompt used to generate the image",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"grok-imagine-image-quality",
_grok_image_edit_model_inputs(max_ref_images=3, with_aspect_ratio=True),
),
IO.DynamicCombo.Option(
"grok-imagine-image-pro",
_grok_image_edit_model_inputs(max_ref_images=1, with_aspect_ratio=False),
),
IO.DynamicCombo.Option(
"grok-imagine-image",
_grok_image_edit_model_inputs(max_ref_images=3, with_aspect_ratio=True),
),
],
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model", "model.resolution", "model.number_of_images"],
),
expr="""
(
$isQualityModel := widgets.model = "grok-imagine-image-quality";
$isPro := $contains(widgets.model, "pro");
$res := $lookup(widgets, "model.resolution");
$n := $lookup(widgets, "model.number_of_images");
$rate := $isQualityModel
? ($res = "1k" ? 0.05 : 0.07)
: ($isPro ? 0.07 : 0.02);
$base := $isQualityModel ? 0.01 : 0.002;
$output := $rate * $n;
$isPro
? {"type":"usd","usd": $base + $output}
: {"type":"range_usd","min_usd": $base + $output, "max_usd": 3 * $base + $output}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_id = model["model"]
resolution = model["resolution"]
number_of_images = model["number_of_images"]
images_dict = model.get("images") or {}
aspect_ratio = model.get("aspect_ratio", "auto")
image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
n_images = sum(get_number_of_images(t) for t in image_tensors)
if n_images < 1:
raise ValueError("At least one image is required for editing.")
if model_id == "grok-imagine-image-pro" and n_images > 1:
raise ValueError("The pro model supports only 1 input image.")
if model_id != "grok-imagine-image-pro" and n_images > 3:
raise ValueError("A maximum of 3 input images is supported.")
if aspect_ratio != "auto" and n_images == 1:
raise ValueError(
"Custom aspect ratio is only allowed when multiple images are connected to the image input."
)
flat_tensors: list[torch.Tensor] = []
for tensor in image_tensors:
if len(tensor.shape) == 4:
flat_tensors.extend(tensor[i] for i in range(tensor.shape[0]))
else:
flat_tensors.append(tensor)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
data=ImageEditRequest(
model=model_id,
images=[
InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in flat_tensors
],
prompt=prompt,
resolution=resolution.lower(),
n=number_of_images,
seed=seed,
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
),
response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
return IO.NodeOutput(
torch.cat(
[await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]],
)
)
class GrokVideoNode(IO.ComfyNode):
@classmethod
@ -737,6 +930,7 @@ class GrokExtension(ComfyExtension):
return [
GrokImageNode,
GrokImageEditNode,
GrokImageEditNodeV2,
GrokVideoNode,
GrokVideoReferenceNode,
GrokVideoEditNode,

View File

@ -27,6 +27,7 @@ from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_bytesio,
downscale_image_tensor,
get_number_of_images,
poll_op,
sync_op,
tensor_to_base64_string,
@ -372,6 +373,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
display_name="OpenAI GPT Image 2",
category="api node/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
is_deprecated=True,
inputs=[
IO.String.Input(
"prompt",
@ -640,6 +642,316 @@ class OpenAIGPTImage1(IO.ComfyNode):
return IO.NodeOutput(await validate_and_cast_response(response))
def _gpt_image_shared_inputs():
"""Inputs shared by all GPT Image models (quality + reference images + mask)."""
return [
IO.Combo.Input(
"quality",
default="low",
options=["low", "medium", "high"],
tooltip="Image quality, affects cost and generation time.",
),
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 17)],
min=0,
),
tooltip="Optional reference image(s) for image editing. Up to 16 images.",
),
IO.Mask.Input(
"mask",
optional=True,
tooltip="Optional mask for inpainting (white areas will be replaced). "
"Requires exactly one reference image.",
),
]
def _gpt_image_legacy_model_inputs():
"""Per-model widget set for legacy gpt-image-1 / gpt-image-1.5 (4 base sizes, transparent bg allowed)."""
return [
IO.Combo.Input(
"size",
default="auto",
options=["auto", "1024x1024", "1024x1536", "1536x1024"],
tooltip="Image size.",
),
IO.Combo.Input(
"background",
default="auto",
options=["auto", "opaque", "transparent"],
tooltip="Return image with or without background.",
),
*_gpt_image_shared_inputs(),
]
class OpenAIGPTImageNodeV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="OpenAIGPTImageNodeV2",
display_name="OpenAI GPT Image 2",
category="api node/image/OpenAI",
description="Generates images via OpenAI's GPT Image endpoint.",
inputs=[
IO.String.Input(
"prompt",
default="",
multiline=True,
tooltip="Text prompt for GPT Image",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"gpt-image-2",
[
IO.Combo.Input(
"size",
default="auto",
options=[
"auto",
"1024x1024",
"1024x1536",
"1536x1024",
"2048x2048",
"2048x1152",
"1152x2048",
"3840x2160",
"2160x3840",
"Custom",
],
tooltip="Image size. Select 'Custom' to use the custom width and height.",
),
IO.Int.Input(
"custom_width",
default=1024,
min=1024,
max=3840,
step=16,
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16.",
),
IO.Int.Input(
"custom_height",
default=1024,
min=1024,
max=3840,
step=16,
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16.",
),
IO.Combo.Input(
"background",
default="auto",
options=["auto", "opaque"],
tooltip="Return image with or without background.",
),
*_gpt_image_shared_inputs(),
],
),
IO.DynamicCombo.Option("gpt-image-1.5", _gpt_image_legacy_model_inputs()),
IO.DynamicCombo.Option("gpt-image-1", _gpt_image_legacy_model_inputs()),
],
),
IO.Int.Input(
"n",
default=1,
min=1,
max=8,
step=1,
tooltip="How many images to generate",
display_mode=IO.NumberDisplay.number,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="not implemented yet in backend",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.quality", "n"]),
expr="""
(
$ranges := {
"gpt-image-1": {
"low": [0.011, 0.02],
"medium": [0.042, 0.07],
"high": [0.167, 0.25]
},
"gpt-image-1.5": {
"low": [0.009, 0.02],
"medium": [0.034, 0.062],
"high": [0.133, 0.22]
},
"gpt-image-2": {
"low": [0.0048, 0.019],
"medium": [0.041, 0.168],
"high": [0.165, 0.67]
}
};
$range := $lookup($lookup($ranges, widgets.model), $lookup(widgets, "model.quality"));
$nRaw := widgets.n;
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
($n = 1)
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}}
: {
"type":"range_usd",
"min_usd": $range[0] * $n,
"max_usd": $range[1] * $n,
"format": { "suffix": "/Run", "approximate": true }
}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
n: int,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
model_id = model["model"]
size = model["size"]
background = model["background"]
quality = model["quality"]
custom_width = model.get("custom_width", 1024)
custom_height = model.get("custom_height", 1024)
images_dict = model.get("images") or {}
image_tensors: list[Input.Image] = [t for t in images_dict.values() if t is not None]
n_images = sum(get_number_of_images(t) for t in image_tensors)
mask = model.get("mask")
if mask is not None and n_images == 0:
raise ValueError("Cannot use a mask without an input image")
if size == "Custom":
if custom_width % 16 != 0 or custom_height % 16 != 0:
raise ValueError(
f"Custom width and height must be multiples of 16, got {custom_width}x{custom_height}"
)
if max(custom_width, custom_height) > 3840:
raise ValueError(
f"Custom resolution max edge must be <= 3840, got {custom_width}x{custom_height}"
)
ratio = max(custom_width, custom_height) / min(custom_width, custom_height)
if ratio > 3:
raise ValueError(
f"Custom resolution aspect ratio must not exceed 3:1, got {custom_width}x{custom_height}"
)
total_pixels = custom_width * custom_height
if not 655_360 <= total_pixels <= 8_294_400:
raise ValueError(
f"Custom resolution total pixels must be between 655,360 and 8,294,400, got {total_pixels}"
)
size = f"{custom_width}x{custom_height}"
if model_id == "gpt-image-1":
price_extractor = calculate_tokens_price_image_1
elif model_id == "gpt-image-1.5":
price_extractor = calculate_tokens_price_image_1_5
elif model_id == "gpt-image-2":
price_extractor = calculate_tokens_price_image_2_0
else:
raise ValueError(f"Unknown model: {model_id}")
if image_tensors:
flat: list[torch.Tensor] = []
for tensor in image_tensors:
if len(tensor.shape) == 4:
flat.extend(tensor[i : i + 1] for i in range(tensor.shape[0]))
else:
flat.append(tensor.unsqueeze(0))
files = []
for i, single_image in enumerate(flat):
scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze()
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
if len(flat) == 1:
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
else:
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
if mask is not None:
if len(flat) != 1:
raise Exception("Cannot use a mask with multiple image")
ref_image = flat[0]
if mask.shape[1:] != ref_image.shape[1:-1]:
raise Exception("Mask and Image must be the same size")
_, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4, device="cpu")
rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
scaled_mask = downscale_image_tensor(
rgba_mask.unsqueeze(0), total_pixels=2048 * 2048
).squeeze()
mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_img_byte_arr = BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageEditRequest(
model=model_id,
prompt=prompt,
quality=quality,
background=background,
n=n,
size=size,
moderation="low",
),
content_type="multipart/form-data",
files=files,
price_extractor=price_extractor,
)
else:
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
response_model=OpenAIImageGenerationResponse,
data=OpenAIImageGenerationRequest(
model=model_id,
prompt=prompt,
quality=quality,
background=background,
n=n,
size=size,
moderation="low",
),
price_extractor=price_extractor,
)
return IO.NodeOutput(await validate_and_cast_response(response))
class OpenAIChatNode(IO.ComfyNode):
"""
Node to generate text responses from an OpenAI model.
@ -999,6 +1311,7 @@ class OpenAIExtension(ComfyExtension):
OpenAIDalle2,
OpenAIDalle3,
OpenAIGPTImage1,
OpenAIGPTImageNodeV2,
OpenAIChatNode,
OpenAIInputFiles,
OpenAIChatConfig,

View File

@ -143,7 +143,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
if reference_images:
references = []
for key in reference_images:
url = await upload_image_to_comfyapi(cls, reference_images[key])
url = await upload_image_to_comfyapi(cls, reference_images[key], mime_type="image/png")
references.append(QuiverImageObject(url=url))
if len(references) > 4:
raise ValueError("Maximum 4 reference images are allowed.")
@ -252,7 +252,7 @@ class QuiverImageToSVGNode(IO.ComfyNode):
model: dict,
seed: int,
) -> IO.NodeOutput:
image_url = await upload_image_to_comfyapi(cls, image)
image_url = await upload_image_to_comfyapi(cls, image, mime_type="image/png")
response = await sync_op(
cls,

View File

@ -86,6 +86,37 @@ def sample_euler_pp(model, x, sigmas, extra_args=None, callback=None, disable=No
return x
class SamplerLCM(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SamplerLCM",
category="sampling/samplers",
description=("LCM sampler with tunable per-step noise. s_noise is a multiplier on the model's training noise scale"),
inputs=[
io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01,
tooltip="Per-step noise multiplier at the first step (1.0 = match training)."),
io.Float.Input("s_noise_end", default=1.0, min=0.0, max=64.0, step=0.01,
tooltip="Per-step noise multiplier at the last step. Set equal to s_noise for a constant schedule."),
io.Float.Input("noise_clip_std", default=0.0, min=0.0, max=10.0, step=0.01,
tooltip="Clamp per-step noise to +/- N*std. 0 disables."),
],
outputs=[io.Sampler.Output()],
)
@classmethod
def execute(cls, s_noise, s_noise_end, noise_clip_std) -> io.NodeOutput:
sampler = comfy.samplers.ksampler(
"lcm",
{
"s_noise": float(s_noise),
"s_noise_end": float(s_noise_end),
"noise_clip_std": float(noise_clip_std),
},
)
return io.NodeOutput(sampler)
class SamplerEulerCFGpp(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
@ -114,6 +145,7 @@ class AdvancedSamplersExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SamplerLCMUpscale,
SamplerLCM,
SamplerEulerCFGpp,
]

View File

@ -82,6 +82,8 @@ class VAEEncodeAudio(IO.ComfyNode):
@classmethod
def execute(cls, vae, audio) -> IO.NodeOutput:
if audio is None:
raise ValueError("VAEEncodeAudio: input audio is None (source video may have no audio track).")
sample_rate = audio["sample_rate"]
vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
if vae_sample_rate != sample_rate:
@ -171,6 +173,8 @@ class SaveAudio(IO.ComfyNode):
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
if audio is None:
raise ValueError("SaveAudio: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
)
@ -198,6 +202,8 @@ class SaveAudioMP3(IO.ComfyNode):
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
if audio is None:
raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
@ -226,6 +232,8 @@ class SaveAudioOpus(IO.ComfyNode):
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
if audio is None:
raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
@ -252,6 +260,8 @@ class PreviewAudio(IO.ComfyNode):
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
if audio is None:
raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).")
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
save_flac = execute # TODO: remove
@ -297,6 +307,7 @@ class LoadAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
input_dir = folder_paths.get_input_directory()
os.makedirs(input_dir, exist_ok=True)
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return IO.Schema(
node_id="LoadAudio",
@ -391,21 +402,26 @@ class TrimAudioDuration(IO.ComfyNode):
@classmethod
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
if audio_length == 0:
return IO.NodeOutput(audio)
if start_index < 0:
start_frame = audio_length + int(round(start_index * sample_rate))
else:
start_frame = int(round(start_index * sample_rate))
start_frame = max(0, min(start_frame, audio_length - 1))
start_frame = max(0, min(start_frame, audio_length))
end_frame = start_frame + int(round(duration * sample_rate))
end_frame = max(0, min(end_frame, audio_length))
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
raise ValueError("TrimAudioDuration: Start time must be less than end time and be within the audio length.")
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
@ -432,11 +448,13 @@ class SplitAudioChannels(IO.ComfyNode):
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None, None)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
if waveform.shape[1] != 2:
raise ValueError("AudioSplit: Input audio has only one channel.")
raise ValueError(f"AudioSplit: Input audio must be stereo (2 channels), got {waveform.shape[1]} channel(s).")
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
@ -464,6 +482,12 @@ class JoinAudioChannels(IO.ComfyNode):
@classmethod
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
if audio_left is None and audio_right is None:
return IO.NodeOutput(None)
if audio_left is None:
return IO.NodeOutput(audio_right)
if audio_right is None:
return IO.NodeOutput(audio_left)
waveform_left = audio_left["waveform"]
sample_rate_left = audio_left["sample_rate"]
waveform_right = audio_right["waveform"]
@ -537,6 +561,12 @@ class AudioConcat(IO.ComfyNode):
@classmethod
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
if audio1 is None and audio2 is None:
return IO.NodeOutput(None)
if audio1 is None:
return IO.NodeOutput(audio2)
if audio2 is None:
return IO.NodeOutput(audio1)
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@ -584,6 +614,12 @@ class AudioMerge(IO.ComfyNode):
@classmethod
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
if audio1 is None and audio2 is None:
return IO.NodeOutput(None)
if audio1 is None:
return IO.NodeOutput(audio2)
if audio2 is None:
return IO.NodeOutput(audio1)
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@ -594,6 +630,9 @@ class AudioMerge(IO.ComfyNode):
length_1 = waveform_1.shape[-1]
length_2 = waveform_2.shape[-1]
if length_1 == 0 or length_2 == 0:
return IO.NodeOutput({"waveform": waveform_1, "sample_rate": output_sample_rate})
if length_2 > length_1:
logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
waveform_2 = waveform_2[..., :length_1]
@ -645,6 +684,8 @@ class AudioAdjustVolume(IO.ComfyNode):
@classmethod
def execute(cls, audio, volume) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None)
if volume == 0:
return IO.NodeOutput(audio)
waveform = audio["waveform"]
@ -728,8 +769,14 @@ class AudioEqualizer3Band(IO.ComfyNode):
@classmethod
def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
if audio is None:
return IO.NodeOutput(None)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
if waveform.shape[-1] == 0:
return IO.NodeOutput(audio)
eq_waveform = waveform.clone()
# 1. Apply Low Shelf (Bass)

View File

@ -0,0 +1,256 @@
from typing_extensions import override
import torch
import comfy.model_management
import comfy.patcher_extension
import node_helpers
from comfy_api.latest import ComfyExtension, io
class EmptyHiDreamO1LatentImage(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EmptyHiDreamO1LatentImage",
display_name="Empty HiDream-O1 Latent Image",
category="latent/image",
description=(
"Empty pixel-space latent for HiDream-O1-Image. The model was "
"trained at ~4 megapixels; lower resolutions go off-distribution "
"and quality regresses noticeably. Trained resolutions: "
"2048x2048, 2304x1728, 1728x2304, 2560x1440, 1440x2560, "
"2496x1664, 1664x2496, 3104x1312, 1312x3104, 2304x1792, 1792x2304."
),
inputs=[
io.Int.Input(id="width", default=2048, min=64, max=4096, step=32),
io.Int.Input(id="height", default=2048, min=64, max=4096, step=32),
io.Int.Input(id="batch_size", default=1, min=1, max=64),
],
outputs=[io.Latent().Output()],
)
@classmethod
def execute(cls, *, width: int, height: int, batch_size: int = 1) -> io.NodeOutput:
latent = torch.zeros(
(batch_size, 3, height, width),
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput({"samples": latent})
class HiDreamO1ReferenceImages(io.ComfyNode):
"""Attach reference images to both positive and negative conditioning."""
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="HiDreamO1ReferenceImages",
display_name="HiDream-O1 Reference Images",
category="conditioning/image",
description=(
"Attach 1-10 reference images to conditioning, one for edit instruction"
"or multiple for subject-driven personalization."
),
inputs=[
io.Conditioning.Input(id="positive"),
io.Conditioning.Input(id="negative"),
io.Autogrow.Input(
"images",
template=io.Autogrow.TemplateNames(
io.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 11)],
min=1,
),
tooltip=("Reference images. 1 image = instruction edit; 2-10 images = multi reference."
),
),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
],
)
@classmethod
def execute(cls, *, positive, negative, images: io.Autogrow.Type) -> io.NodeOutput:
refs = [images[f"image_{i}"] for i in range(1, 11) if f"image_{i}" in images]
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": refs}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": refs}, append=True)
return io.NodeOutput(positive, negative)
class HiDreamO1PatchSeamSmoothing(io.ComfyNode):
PATCH_SIZE = 32
EDGE_FEATHER = 4
# Shift presets per (pattern, N). 8-pass = 4-quadrant + 4 quarter-patch offsets.
SHIFTS_BY_PATTERN = {
("single_shift", 2): [(0, 0), (16, 16)],
("single_shift", 4): [(0, 0), (16, 0), (0, 16), (16, 16)],
("single_shift", 8): [(0, 0), (16, 0), (0, 16), (16, 16),
(8, 8), (24, 8), (8, 24), (24, 24)],
("symmetric", 2): [(-8, -8), (8, 8)],
("symmetric", 4): [(-8, -8), (8, -8), (-8, 8), (8, 8)],
("symmetric", 8): [(-12, -12), (4, -12), (-12, 4), (4, 4),
(-4, -4), (12, -4), (-4, 12), (12, 12)],
}
RAMP_LEVELS = {
"2": [2],
"4": [4],
"ramp_2_4": [2, 4],
"ramp_2_4_8": [2, 4, 8],
}
@staticmethod
def _hann_tile(cy: int, cx: int, size: int = 32) -> torch.Tensor:
"""size x size Hann tile peaking at (cy, cx) within a patch."""
half = size // 2
yy = torch.arange(size).view(size, 1)
xx = torch.arange(size).view(1, size)
dy = ((yy - cy + half) % size) - half
dx = ((xx - cx + half) % size) - half
return 0.25 * (1 + torch.cos(torch.pi * dy / half)) * (1 + torch.cos(torch.pi * dx / half))
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="HiDreamO1PatchSeamSmoothing",
display_name="HiDream-O1 Patch Seam Smoothing",
category="advanced/model",
is_experimental=True,
description=(
"Average the model output across multiple shifted patch-grid "
"positions during the late portion of sampling. Cancels seams."
),
inputs=[
io.Model.Input(id="model"),
io.Float.Input(id="start_percent", default=0.8, min=0.0, max=1.0, step=0.01,
tooltip="Sampling progress (0=start, 1=end) at which the blend turns ON.",
),
io.Float.Input(id="end_percent", default=1.0, min=0.0, max=1.0, step=0.01,
tooltip="Sampling progress at which the blend turns OFF.",
),
io.Combo.Input(
id="pattern",
options=["single_shift", "symmetric"],
default="single_shift",
tooltip="Shift layout. single_shift: one pass at the natural patch grid + others offset. symmetric: all passes off-grid, shifts split around origin.",
),
io.Combo.Input(
id="passes",
options=["2", "4", "ramp_2_4", "ramp_2_4_8"],
default="2",
tooltip="Number of passes per gated step. 2/4 = fixed. ramp_*: pass count increases as sampling approaches end (more smoothing where seams are most visible).",
),
io.Combo.Input(
id="blend",
options=["average", "window", "median"],
default="average",
tooltip="average: equal-weight mean. window: Hann-windowed weighting favoring each pass away from its patch boundaries. median: per-pixel median, rejects wraparound-outlier passes.",
),
io.Float.Input(id="strength", default=1.0, min=0.0, max=1.0, step=0.01,
tooltip="Interpolation between the natural-grid pred (0) and the averaged result (1).",
),
],
outputs=[io.Model.Output()],
)
@classmethod
def execute(cls, *, model, start_percent: float, end_percent: float, pattern: str, passes: str, blend: str, strength: float) -> io.NodeOutput:
if strength <= 0.0 or end_percent <= start_percent:
return io.NodeOutput(model)
P = cls.PATCH_SIZE
half = P // 2
shift_levels = [cls.SHIFTS_BY_PATTERN[(pattern, n)] for n in cls.RAMP_LEVELS[passes]]
if blend == "window":
window_tile_levels = [
torch.stack([cls._hann_tile((half - sy) % P, (half - sx) % P, P) for sy, sx in lst], dim=0)
for lst in shift_levels
]
else:
window_tile_levels = [None] * len(shift_levels)
m = model.clone()
model_sampling = m.get_model_object("model_sampling")
multiplier = float(model_sampling.multiplier)
start_t = float(model_sampling.percent_to_sigma(start_percent)) * multiplier
end_t = float(model_sampling.percent_to_sigma(end_percent)) * multiplier
edge_ramp_cache: dict = {}
def get_edge_ramp(H: int, W: int, device, dtype) -> torch.Tensor:
key = (H, W, device, dtype)
cached = edge_ramp_cache.get(key)
if cached is not None:
return cached
feather = cls.EDGE_FEATHER
ys = torch.minimum(torch.arange(H, device=device, dtype=torch.float32),
(H - 1) - torch.arange(H, device=device, dtype=torch.float32))
xs = torch.minimum(torch.arange(W, device=device, dtype=torch.float32),
(W - 1) - torch.arange(W, device=device, dtype=torch.float32))
y_mask = ((ys - P) / feather).clamp(0, 1)
x_mask = ((xs - P) / feather).clamp(0, 1)
ramp = (y_mask[:, None] * x_mask[None, :]).to(dtype)
edge_ramp_cache[key] = ramp
return ramp
def smoothing_wrapper(executor, *args, **kwargs):
x = args[0]
t = float(args[1][0])
pred = executor(*args, **kwargs)
if not (end_t <= t <= start_t):
return pred
# Pick shift-level by sigma phase across the gated range.
if len(shift_levels) == 1:
level_idx = 0
else:
phase = (start_t - t) / max(start_t - end_t, 1e-8)
level_idx = min(int(phase * len(shift_levels)), len(shift_levels) - 1)
shifts = shift_levels[level_idx]
window_tiles = window_tile_levels[level_idx]
preds = []
for sy, sx in shifts:
if sy == 0 and sx == 0:
preds.append(pred)
continue
x_rolled = torch.roll(x, shifts=(sy, sx), dims=(-2, -1))
pred_rolled = executor(x_rolled, *args[1:], **kwargs)
preds.append(torch.roll(pred_rolled, shifts=(-sy, -sx), dims=(-2, -1)))
stacked = torch.stack(preds, dim=0) # (N, B, C, H, W)
_, _, _, H, W = stacked.shape
if blend == "window":
N = stacked.shape[0]
tiles = window_tiles.to(device=stacked.device, dtype=stacked.dtype)
w = tiles.repeat(1, H // P, W // P)[:, :H, :W]
sum_w = w.sum(dim=0, keepdim=True)
w = torch.where(sum_w < 1e-3, torch.full_like(w, 1.0 / N), w / sum_w.clamp(min=1e-8))
avg = (stacked * w[:, None, None, :, :]).sum(dim=0)
elif blend == "median":
avg = torch.median(stacked, dim=0).values
else:
avg = stacked.mean(dim=0)
# Mask out the P-px wraparound contamination strip at each edge.
mask = get_edge_ramp(H, W, pred.device, pred.dtype)
return pred * (1.0 - mask * strength) + avg * (mask * strength)
m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "hidream_o1_patch_seam_smoothing", smoothing_wrapper)
return io.NodeOutput(m)
class HiDreamO1Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyHiDreamO1LatentImage,
HiDreamO1ReferenceImages,
HiDreamO1PatchSeamSmoothing,
]
async def comfy_entrypoint() -> HiDreamO1Extension:
return HiDreamO1Extension()

View File

@ -1,12 +1,7 @@
import torch
import os
import json
import struct
import numpy as np
from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
import folder_paths
import comfy.model_management
from comfy.cli_args import args
from comfy_extras.nodes_save_3d import pack_variable_mesh_batch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types
from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa
@ -444,7 +439,9 @@ class VoxelToMeshBasic(IO.ComfyNode):
vertices.append(v)
faces.append(f)
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces):
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces))
decode = execute # TODO: remove
@ -481,206 +478,13 @@ class VoxelToMesh(IO.ComfyNode):
vertices.append(v)
faces.append(f)
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces):
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces))
decode = execute # TODO: remove
def save_glb(vertices, faces, filepath, metadata=None):
"""
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
Parameters:
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
filepath: str - Output filepath (should end with .glb)
"""
# Convert tensors to numpy arrays
vertices_np = vertices.cpu().numpy().astype(np.float32)
faces_np = faces.cpu().numpy().astype(np.uint32)
vertices_buffer = vertices_np.tobytes()
indices_buffer = faces_np.tobytes()
def pad_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b'\x00' * padding_length
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
buffer_data = vertices_buffer_padded + indices_buffer_padded
vertices_byte_length = len(vertices_buffer)
vertices_byte_offset = 0
indices_byte_length = len(indices_buffer)
indices_byte_offset = len(vertices_buffer_padded)
gltf = {
"asset": {"version": "2.0", "generator": "ComfyUI"},
"buffers": [
{
"byteLength": len(buffer_data)
}
],
"bufferViews": [
{
"buffer": 0,
"byteOffset": vertices_byte_offset,
"byteLength": vertices_byte_length,
"target": 34962 # ARRAY_BUFFER
},
{
"buffer": 0,
"byteOffset": indices_byte_offset,
"byteLength": indices_byte_length,
"target": 34963 # ELEMENT_ARRAY_BUFFER
}
],
"accessors": [
{
"bufferView": 0,
"byteOffset": 0,
"componentType": 5126, # FLOAT
"count": len(vertices_np),
"type": "VEC3",
"max": vertices_np.max(axis=0).tolist(),
"min": vertices_np.min(axis=0).tolist()
},
{
"bufferView": 1,
"byteOffset": 0,
"componentType": 5125, # UNSIGNED_INT
"count": faces_np.size,
"type": "SCALAR"
}
],
"meshes": [
{
"primitives": [
{
"attributes": {
"POSITION": 0
},
"indices": 1,
"mode": 4 # TRIANGLES
}
]
}
],
"nodes": [
{
"mesh": 0
}
],
"scenes": [
{
"nodes": [0]
}
],
"scene": 0
}
if metadata is not None:
gltf["asset"]["extras"] = metadata
# Convert the JSON to bytes
gltf_json = json.dumps(gltf).encode('utf8')
def pad_json_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b' ' * padding_length
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
# Create the GLB header
# Magic glTF
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
# Create JSON chunk header (chunk type 0)
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
# Create BIN chunk header (chunk type 1)
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
# Write the GLB file
with open(filepath, 'wb') as f:
f.write(glb_header)
f.write(json_chunk_header)
f.write(gltf_json_padded)
f.write(bin_chunk_header)
f.write(buffer_data)
return filepath
class SaveGLB(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveGLB",
display_name="Save 3D Model",
search_aliases=["export 3d model", "save mesh"],
category="3d",
essentials_category="Basics",
is_output_node=True,
inputs=[
IO.MultiType.Input(
IO.Mesh.Input("mesh"),
types=[
IO.File3DGLB,
IO.File3DGLTF,
IO.File3DOBJ,
IO.File3DFBX,
IO.File3DSTL,
IO.File3DUSDZ,
IO.File3DAny,
],
tooltip="Mesh or 3D file to save",
),
IO.String.Input("filename_prefix", default="3d/ComfyUI"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
)
@classmethod
def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = []
metadata = {}
if not args.disable_metadata:
if cls.hidden.prompt is not None:
metadata["prompt"] = json.dumps(cls.hidden.prompt)
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
if isinstance(mesh, Types.File3D):
# Handle File3D input - save BytesIO data to output folder
ext = mesh.format or "glb"
f = f"{filename}_{counter:05}_.{ext}"
mesh.save_to(os.path.join(full_output_folder, f))
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
else:
# Handle Mesh input - save vertices and faces as GLB
for i in range(mesh.vertices.shape[0]):
f = f"{filename}_{counter:05}_.glb"
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
counter += 1
return IO.NodeOutput(ui={"3d": results})
class Hunyuan3dExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -691,7 +495,6 @@ class Hunyuan3dExtension(ComfyExtension):
VAEDecodeHunyuan3D,
VoxelToMeshBasic,
VoxelToMesh,
SaveGLB,
]

View File

@ -136,7 +136,7 @@ class ImageFromBatch(IO.ComfyNode):
category="image/batch",
inputs=[
IO.Image.Input("image"),
IO.Int.Input("batch_index", default=0, min=0, max=4095),
IO.Int.Input("batch_index", default=0, min=-MAX_RESOLUTION, max=MAX_RESOLUTION),
IO.Int.Input("length", default=1, min=1, max=4096),
],
outputs=[IO.Image.Output()],
@ -145,7 +145,9 @@ class ImageFromBatch(IO.ComfyNode):
@classmethod
def execute(cls, image, batch_index, length) -> IO.NodeOutput:
s_in = image
batch_index = min(s_in.shape[0] - 1, batch_index)
if batch_index < 0:
batch_index += s_in.shape[0]
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
length = min(s_in.shape[0] - batch_index, length)
s = s_in[batch_index:batch_index + length].clone()
return IO.NodeOutput(s)

View File

@ -14,6 +14,49 @@ from typing_extensions import override
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy_api.latest import ComfyExtension, io
ICLoRAParameters = io.Custom("IC_LORA_PARAMETERS")
class GetICLoRAParameters(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GetICLoRAParameters",
display_name="Get IC-LoRA Parameters",
description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded "
"model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).",
category="conditioning/video_models",
search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"],
inputs=[
io.Model.Input(
"iclora_model",
tooltip="Direct output from a LoRA Loader for the specific IC-LoRA "
"from which to extract the metadata.",
),
],
outputs=[
ICLoRAParameters.Output(
"iclora_parameters",
tooltip="IC-LoRA parameters extracted from the LoRA metadata "
"(eg. reference_downscale_factor). Connect to LTXVAddGuide "
"if the LoRA requires special handling of the guides.",
),
],
)
@classmethod
def execute(cls, iclora_model) -> io.NodeOutput:
metadata = iclora_model.get_attachment("lora_metadata")
factor = 1
if metadata:
try:
factor = max(1, round(float(metadata.get("reference_downscale_factor", 1))))
except (TypeError, ValueError):
factor = 1
parameters = {"reference_downscale_factor": factor}
return io.NodeOutput(parameters)
class EmptyLTXVLatentVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -219,7 +262,15 @@ class LTXVAddGuide(io.ComfyNode):
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
ICLoRAParameters.Input(
"iclora_parameters",
optional=True,
tooltip="Optional IC-LoRA parameters from a Get IC-LoRA Parameters node. "
"Used for adjusting guide processing as required by certain IC-LoRAs "
"(eg. those with a reference_downscale_factor > 1). "
"When chained, each LTXVAddGuide uses only the parameters connected to it.",
),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
@ -229,14 +280,41 @@ class LTXVAddGuide(io.ComfyNode):
)
@classmethod
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
target_height = int(latent_height * height_scale_factor / latent_downscale_factor)
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), target_width, target_height, "bilinear", crop="center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t
@classmethod
def dilate_latent(cls, guide_latent, latent_downscale_factor):
if latent_downscale_factor <= 1:
return guide_latent, None
scale = int(latent_downscale_factor)
dilated_shape = guide_latent.shape[:3] + (guide_latent.shape[3] * scale, guide_latent.shape[4] * scale)
dilated = torch.zeros(dilated_shape, device=guide_latent.device, dtype=guide_latent.dtype)
dilated[..., ::scale, ::scale] = guide_latent
dilated_mask = torch.full(
(dilated.shape[0], 1, dilated.shape[2], dilated.shape[3], dilated.shape[4]),
-1.0, device=guide_latent.device, dtype=guide_latent.dtype,
)
dilated_mask[..., ::scale, ::scale] = 1.0
return dilated, dilated_mask
@classmethod
def get_reference_downscale_factor(cls, iclora_parameters):
if not iclora_parameters:
return 1
try:
factor = max(1, round(float(iclora_parameters.get("reference_downscale_factor", 1))))
except (TypeError, ValueError):
factor = 1
return factor
@classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
time_scale_factor, _, _ = scale_factors
@ -298,7 +376,7 @@ class LTXVAddGuide(io.ComfyNode):
else:
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength,
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
dtype=noise_mask.dtype,
device=noise_mask.device,
)
@ -318,7 +396,7 @@ class LTXVAddGuide(io.ComfyNode):
mask = torch.full(
(noise_mask.shape[0], 1, cond_length, 1, 1),
1.0 - strength,
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
dtype=noise_mask.dtype,
device=noise_mask.device,
)
@ -332,13 +410,43 @@ class LTXVAddGuide(io.ComfyNode):
return latent_image, noise_mask
@classmethod
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, iclora_parameters=None) -> io.NodeOutput:
scale_factors = vae.downscale_index_formula
latent_image = latent["samples"]
noise_mask = get_noise_mask(latent)
_, _, latent_length, latent_height, latent_width = latent_image.shape
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
latent_downscale_factor = cls.get_reference_downscale_factor(iclora_parameters)
if latent_downscale_factor > 1:
if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0:
raise ValueError(
f"Latent spatial size {latent_width}x{latent_height} must be divisible by "
f"reference_downscale_factor {latent_downscale_factor} from the IC-LoRA parameters."
)
# For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot
time_scale_factor = scale_factors[0]
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
resolved_frame_idx = frame_idx
if frame_idx < 0:
_, num_keyframes = get_keyframe_idxs(positive)
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
if not causal_fix:
image = torch.cat([image[:1], image], dim=0)
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
if not causal_fix:
t = t[:, :, 1:, :, :]
image = image[1:]
guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial-mask downsampling
guide_mask = None
if latent_downscale_factor > 1:
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
@ -352,11 +460,13 @@ class LTXVAddGuide(io.ComfyNode):
t,
strength,
scale_factors,
guide_mask=guide_mask,
latent_downscale_factor=latent_downscale_factor,
causal_fix=causal_fix,
)
# Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
)
@ -776,6 +886,7 @@ class LtxvExtension(ComfyExtension):
ModelSamplingLTXV,
LTXVConditioning,
LTXVScheduler,
GetICLoRAParameters,
LTXVAddGuide,
LTXVPreprocess,
LTXVCropGuides,

View File

@ -40,23 +40,13 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
inverse_mask = torch.ones_like(mask) - mask
source_rgb = source[:, :3, :visible_height, :visible_width]
dest_slice = destination[..., top:bottom, left:right]
if destination.shape[1] == 4:
if torch.max(dest_slice) == 0:
destination[:, :3, top:bottom, left:right] = source_rgb
destination[:, 3:4, top:bottom, left:right] = mask
else:
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
else:
source_portion = mask * source_rgb
destination_portion = inverse_mask * dest_slice
destination[..., top:bottom, left:right] = source_portion + destination_portion
source_portion = mask * source[..., :visible_height, :visible_width]
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
class LatentCompositeMasked(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -95,23 +85,18 @@ class ImageCompositeMasked(IO.ComfyNode):
display_name="Image Composite Masked",
category="image",
inputs=[
IO.Image.Input("destination"),
IO.Image.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("resize_source", default=False),
IO.Image.Input("destination", optional=True),
IO.Mask.Input("mask", optional=True),
],
outputs=[IO.Image.Output()],
)
@classmethod
def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
if destination is None: # transparent rgba
B, H, W, C = source.shape
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
if C == 3:
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
@ -345,7 +330,7 @@ class FeatherMask(IO.ComfyNode):
for x in range(right):
feather_rate = (x + 1) / right
output[:, :, -x] *= feather_rate
output[:, :, -(x + 1)] *= feather_rate
for y in range(top):
feather_rate = (y + 1) / top
@ -353,7 +338,7 @@ class FeatherMask(IO.ComfyNode):
for y in range(bottom):
feather_rate = (y + 1) / bottom
output[:, -y, :] *= feather_rate
output[:, -(y + 1), :] *= feather_rate
return IO.NodeOutput(output)

View File

@ -63,7 +63,7 @@ class MathExpressionNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
autogrow = io.Autogrow.TemplateNames(
input=io.MultiType.Input("value", [io.Float, io.Int]),
input=io.MultiType.Input("value", [io.Float, io.Int, io.Boolean]),
names=list(string.ascii_lowercase),
min=1,
)
@ -82,6 +82,7 @@ class MathExpressionNode(io.ComfyNode):
outputs=[
io.Float.Output(display_name="FLOAT"),
io.Int.Output(display_name="INT"),
io.Boolean.Output(display_name="BOOL"),
],
)
@ -97,7 +98,7 @@ class MathExpressionNode(io.ComfyNode):
result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS)
# bool check must come first because bool is a subclass of int in Python
if isinstance(result, bool) or not isinstance(result, (int, float)):
if not isinstance(result, (int, float)):
raise ValueError(
f"Math Expression '{expression}' must evaluate to a numeric result, "
f"got {type(result).__name__}: {result!r}"
@ -106,7 +107,7 @@ class MathExpressionNode(io.ComfyNode):
raise ValueError(
f"Math Expression '{expression}' produced a non-finite result: {result}"
)
return io.NodeOutput(float(result), int(result))
return io.NodeOutput(float(result), int(result), bool(result))
class MathExtension(ComfyExtension):

View File

@ -134,8 +134,11 @@ class ModelSamplingSD3:
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
original = m.get_model_object("model_sampling")
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift, multiplier=multiplier)
if hasattr(original, "noise_scale"):
model_sampling.set_noise_scale(original.noise_scale)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
@ -300,6 +303,29 @@ class RescaleCFG:
m.set_model_sampler_cfg_function(rescale_cfg)
return (m, )
class ModelNoiseScale:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"noise_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 64.0, "step": 0.01,
"tooltip": "Absolute training noise scale. For example HiDream-O1 base: 8.0, dev: 7.5."}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, noise_scale):
m = model.clone()
original = m.get_model_object("model_sampling")
ms = type(original)(m.model.model_config)
ms.set_parameters(shift=original.shift, multiplier=original.multiplier)
ms.set_noise_scale(noise_scale)
m.add_object_patch("model_sampling", ms)
return (m, )
class ModelComputeDtype:
SEARCH_ALIASES = ["model precision", "change dtype"]
@classmethod
@ -327,6 +353,7 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingSD3": ModelSamplingSD3,
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
"ModelSamplingFlux": ModelSamplingFlux,
"ModelNoiseScale": ModelNoiseScale,
"RescaleCFG": RescaleCFG,
"ModelComputeDtype": ModelComputeDtype,
}

406
comfy_extras/nodes_moge.py Normal file
View File

@ -0,0 +1,406 @@
"""ComfyUI nodes for the native MoGe (Monocular Geometry Estimation) integration."""
from __future__ import annotations
import torch
import comfy.utils
import folder_paths
from comfy_api.latest import ComfyExtension, Types, io
from typing_extensions import override
from comfy.ldm.moge.model import MoGeModel
from comfy.ldm.moge.geometry import triangulate_grid_mesh
from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid
import comfy.model_management
from tqdm.auto import tqdm
MoGeModelType = io.Custom("MOGE_MODEL")
MoGeGeometry = io.Custom("MOGE_GEOMETRY")
# MOGE_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
# "points": torch.Tensor (B, H, W, 3)
# "depth": torch.Tensor (B, H, W)
# "intrinsics": torch.Tensor (B, 3, 3) -- perspective only
# "mask": torch.Tensor (B, H, W) bool
# "normal": torch.Tensor (B, H, W, 3) -- v2 only
# "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present)
def _turbo(x: torch.Tensor) -> torch.Tensor:
"""Anton Mikhailov polynomial approximation of the turbo colormap."""
x = x.clamp(0.0, 1.0)
x2 = x * x
x3 = x2 * x
x4 = x2 * x2
x5 = x4 * x
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
def _normals_from_points(points: torch.Tensor) -> torch.Tensor:
"""Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback)."""
finite = torch.isfinite(points).all(dim=-1)
pts = torch.where(finite.unsqueeze(-1), points, torch.zeros_like(points))
dx = pts[..., :, 2:, :] - pts[..., :, :-2, :]
dy = pts[..., 2:, :, :] - pts[..., :-2, :, :]
dx = torch.nn.functional.pad(dx.permute(0, 3, 1, 2), (1, 1, 0, 0)).permute(0, 2, 3, 1)
dy = torch.nn.functional.pad(dy.permute(0, 3, 1, 2), (0, 0, 1, 1)).permute(0, 2, 3, 1)
# dy x dx (not dx x dy) so the result is outward-facing in OpenCV (Y-down flips the right-hand rule), matching v2's predicted normals.
n = torch.cross(dy, dx, dim=-1)
n = torch.nn.functional.normalize(n, dim=-1)
return torch.where(finite.unsqueeze(-1), n, torch.zeros_like(n))
def _normalize_disparity(depth: torch.Tensor) -> torch.Tensor:
"""Per-batch normalize 1/depth to [0, 1] using 0.1/99.9 percentile clipping."""
out = torch.zeros_like(depth)
for i in range(depth.shape[0]):
d = depth[i]
valid = torch.isfinite(d) & (d > 0)
if not valid.any():
continue
disp = torch.where(valid, 1.0 / d.clamp_min(1e-6), torch.zeros_like(d))
disp_valid = disp[valid]
lo = torch.quantile(disp_valid, 0.001)
hi = torch.quantile(disp_valid, 0.999)
scale = (hi - lo).clamp_min(1e-6)
norm = ((disp - lo) / scale).clamp(0.0, 1.0)
out[i] = torch.where(valid, norm, torch.zeros_like(norm))
return out
class LoadMoGeModel(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadMoGeModel",
display_name="Load MoGe Model",
category="loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("geometry_estimation")),
],
outputs=[MoGeModelType.Output()],
)
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
path = folder_paths.get_full_path_or_raise("geometry_estimation", model_name)
sd = comfy.utils.load_torch_file(path, safe_load=True)
return io.NodeOutput(MoGeModel(sd))
class MoGePanoramaInference(io.ComfyNode):
"""Equirectangular panorama inference: split into 12 perspective views, run
MoGe at fov_x=90 on each, merge via multi-scale Poisson + gradient solve.
v2's predicted normals and metric scale are ignored (per-view scales would not align across seams).
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MoGePanoramaInference",
display_name="MoGe Panorama Inference",
category="image/geometry_estimation",
inputs=[
MoGeModelType.Input("moge_model"),
io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."),
io.Int.Input("resolution_level", default=9, min=0, max=9,
tooltip="Per-view detail (0 = fastest, 9 = most detailed)."),
io.Int.Input("split_resolution", default=512, min=256, max=1024,
tooltip="Resolution of each perspective split."),
io.Int.Input("merge_resolution", default=1920, min=256, max=8192,
tooltip="Long-side resolution of the merged equirect distance map."),
io.Int.Input("batch_size", default=4, min=1, max=12,
tooltip="Views per inference batch (12 splits total)."),
],
outputs=[MoGeGeometry.Output(display_name="moge_geometry")],
)
@classmethod
def execute(cls, moge_model, image, resolution_level, split_resolution, merge_resolution, batch_size) -> io.NodeOutput:
if image.shape[0] != 1:
raise ValueError(f"MoGePanoramaInference takes a single image (got batch of {image.shape[0]})")
image = image[..., :3]
H, W = int(image.shape[1]), int(image.shape[2])
scale = min(merge_resolution / max(H, W), 1.0)
merge_h, merge_w = max(int(H * scale), 32), max(int(W * scale), 32)
extrinsics, intrinsics = get_panorama_cameras()
comfy.model_management.load_model_gpu(moge_model.patcher)
device = moge_model.load_device
img_chw = image[0].movedim(-1, -3).to(device=device, dtype=moge_model.dtype)
splits = split_panorama_image(img_chw, extrinsics, intrinsics, split_resolution)
n_views = splits.shape[0]
# Weight each lsmr solve by 4^level so the final-resolution solve doesn't leave the bar idle.
merge_levels: list[tuple[int, int]] = []
w_, h_ = merge_w, merge_h
while True:
merge_levels.append((w_, h_))
if max(w_, h_) <= 256:
break
w_, h_ = w_ // 2, h_ // 2
merge_levels.reverse()
solve_weight = {wh: 4 ** i for i, wh in enumerate(merge_levels)}
n_merge_view_units = n_views * len(merge_levels)
n_merge_solve_units = sum(solve_weight.values())
pbar = comfy.utils.ProgressBar(n_views + n_merge_view_units + n_merge_solve_units)
done = 0
distance_maps: list = []
masks: list = []
with tqdm(total=n_views, desc="MoGe panorama inference") as tq:
for i in range(0, n_views, batch_size):
batch = splits[i:i + batch_size]
# apply_metric_scale=False: per-view scales would not align across overlap seams.
result = moge_model.infer(batch, resolution_level=resolution_level,
fov_x=90.0, force_projection=True,
apply_mask=False, apply_metric_scale=False)
distance_maps.extend(list(result["points"].float().norm(dim=-1).cpu().numpy()))
masks.extend(list(result["mask"].cpu().numpy()))
n = batch.shape[0]
done += n
pbar.update_absolute(done)
tq.update(n)
with tqdm(total=n_merge_view_units + n_merge_solve_units, desc="MoGe panorama merge: views") as tq:
def _on_merge_view():
nonlocal done
done += 1
pbar.update_absolute(done)
tq.update(1)
def _on_solve_start(w, h):
tq.set_description(f"MoGe panorama merge: solving {w}x{h}")
def _on_solve_end(w, h):
nonlocal done
weight = solve_weight[(w, h)]
done += weight
pbar.update_absolute(done)
tq.update(weight)
tq.set_description("MoGe panorama merge: views")
pano_depth, pano_mask = merge_panorama_depth(
merge_w, merge_h, distance_maps, masks, list(extrinsics), intrinsics,
on_view=_on_merge_view, on_solve_start=_on_solve_start, on_solve_end=_on_solve_end)
pano_depth = torch.from_numpy(pano_depth)
pano_mask = torch.from_numpy(pano_mask)
if (merge_h, merge_w) != (H, W):
pano_depth = torch.nn.functional.interpolate(pano_depth[None, None], size=(H, W), mode="bilinear", align_corners=False).squeeze()
pano_mask = torch.nn.functional.interpolate(pano_mask[None, None].float(), size=(H, W), mode="nearest").squeeze() > 0
# Pixels uncovered by any view's predicted foreground are unconstrained in the lsmr solve and stay at log_depth=0 (depth=1)
if pano_mask.any() and not pano_mask.all():
far = torch.quantile(pano_depth[pano_mask], 0.95) * 5.0
pano_depth = torch.where(pano_mask, pano_depth, far)
directions = torch.from_numpy(spherical_uv_to_directions(_uv_grid(H, W)))
points = (directions * pano_depth[..., None]).unsqueeze(0)
depth = pano_depth.unsqueeze(0)
mask = pano_mask.unsqueeze(0)
# Points stay in MoGe spherical coords; MoGePointMapToMesh applies the spherical->glTF rotation after triangulation
moge_geometry = {"points": points, "depth": depth, "mask": mask, "image": image.cpu()}
return io.NodeOutput(moge_geometry)
class MoGeInference(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MoGeInference",
display_name="MoGe Inference",
category="image/geometry_estimation",
inputs=[
MoGeModelType.Input("moge_model"),
io.Image.Input("image"),
io.Int.Input("resolution_level", default=9, min=0, max=9,
tooltip="0 = fastest, 9 = most detail."),
io.Float.Input("fov_x_degrees", default=0.0, min=0.0, max=170.0, step=0.1, advanced=True,
tooltip="Horizontal field of view of the source camera. Sets the focal length used to unproject the depth map into 3D. 0 = auto-recover from the predicted points."),
io.Int.Input("batch_size", default=4, min=1, max=64,
tooltip="Images per inference call. Lower if you OOM on a long video / image set."),
io.Boolean.Input("force_projection", default=True, advanced=True),
io.Boolean.Input("apply_mask", default=True, advanced=True,
tooltip="Set masked-out (sky / invalid) pixels to inf in points and depth so meshing culls them. Disable to keep the raw predicted geometry everywhere; the mask is still returned separately."),
],
outputs=[MoGeGeometry.Output(display_name="moge_geometry")],
)
@classmethod
def execute(cls, moge_model, image, resolution_level, fov_x_degrees, batch_size, force_projection, apply_mask) -> io.NodeOutput:
image = image[..., :3]
bchw = image.movedim(-1, -3).contiguous()
B = bchw.shape[0]
fov = None if fov_x_degrees <= 0 else float(fov_x_degrees)
pbar = comfy.utils.ProgressBar(B)
chunks: list[dict] = []
with tqdm(total=B, desc="MoGe inference") as tq:
for i in range(0, B, batch_size):
chunk = bchw[i:i + batch_size]
chunks.append(moge_model.infer(chunk, resolution_level=resolution_level, fov_x=fov,
force_projection=force_projection, apply_mask=apply_mask))
pbar.update_absolute(min(i + batch_size, B))
tq.update(chunk.shape[0])
def stack(field):
vals = [c[field] for c in chunks if field in c]
return torch.cat(vals, dim=0) if vals else None
moge_geometry = {"image": image.cpu()}
for field in ("points", "depth", "intrinsics", "mask", "normal"):
v = stack(field)
if v is not None:
moge_geometry[field] = v
return io.NodeOutput(moge_geometry)
class MoGeRender(io.ComfyNode):
"""Render a visualization or mask from a MOGE_GEOMETRY packet."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MoGeRender",
display_name="MoGe Render",
category="image/geometry_estimation",
inputs=[
MoGeGeometry.Input("moge_geometry"),
io.Combo.Input("output", options=["depth", "depth_colored", "normal_opengl", "normal_directx", "mask"], default="depth",
tooltip="DirectX vs OpenGL controls the normal-map green-channel convention. DirectX: green = -Y down (Unreal). OpenGL: green = +Y up (Blender, Substance, Unity, glTF)."),
],
outputs=[io.Image.Output()],
)
@classmethod
def execute(cls, moge_geometry, output) -> io.NodeOutput:
is_normal = output in ("normal_directx", "normal_opengl")
opengl = output.endswith("_opengl")
# Pick the input tensor for the chosen mode and validate availability.
if output in ("depth", "depth_colored"):
if "depth" not in moge_geometry:
raise ValueError("moge_geometry has no depth output.")
src = moge_geometry["depth"]
elif is_normal:
if "normal" in moge_geometry:
src = moge_geometry["normal"]
elif "points" in moge_geometry:
src = moge_geometry["points"]
else:
raise ValueError("moge_geometry has neither normals nor points to derive normals from.")
elif output == "mask":
if "mask" not in moge_geometry:
raise ValueError("moge_geometry has no mask output.")
src = moge_geometry["mask"]
else:
raise ValueError(f"Unknown output mode: {output}")
B = src.shape[0]
pbar = comfy.utils.ProgressBar(B)
out: list[torch.Tensor] = []
with tqdm(total=B, desc=f"MoGe render: {output}") as tq:
for i in range(B):
slc = src[i:i + 1].float()
if output in ("depth", "depth_colored"):
d = _normalize_disparity(slc)
out.append(_turbo(d) if output == "depth_colored"
else d.unsqueeze(-1).expand(*d.shape, 3).contiguous())
elif is_normal:
n = slc if "normal" in moge_geometry else _normals_from_points(slc)
# MoGe is OpenCV (Z+ into scene); normal-map convention is Z+ out of surface, so flip Z.
y_sign = -1.0 if opengl else 1.0
n = n * n.new_tensor([1.0, y_sign, -1.0])
out.append((n * 0.5 + 0.5).clamp(0.0, 1.0))
elif output == "mask":
out.append(slc.unsqueeze(-1).expand(*slc.shape, 3).contiguous())
pbar.update_absolute(i + 1)
tq.update(1)
result = torch.cat(out, dim=0).to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput(result)
class MoGePointMapToMesh(io.ComfyNode):
"""Triangulate one image of a MoGe point map into a Types.MESH (UVs + texture)."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MoGePointMapToMesh",
display_name="MoGe Point Map to Mesh",
category="image/geometry_estimation",
inputs=[
MoGeGeometry.Input("moge_geometry"),
io.Int.Input("batch_index", default=0, min=0, max=4096,
tooltip="Which image of a batched MoGe geometry to mesh. Per-image vertex counts "
"differ, so batches can't be stacked into a single MESH."),
io.Int.Input("decimation", default=1, min=1, max=8,
tooltip="Vertex stride; 1 = full resolution."),
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01,
tooltip="Drop pixels whose 3x3 depth span exceeds this fraction. 0 = off."),
io.Boolean.Input("texture", default=True,
tooltip="Carry the source image through as the baseColor texture."),
],
outputs=[io.Mesh.Output()],
)
@classmethod
def execute(cls, moge_geometry, batch_index, decimation, discontinuity_threshold, texture) -> io.NodeOutput:
if "points" not in moge_geometry:
raise ValueError("moge_geometry has no points output.")
points = moge_geometry["points"]
B = points.shape[0]
if batch_index >= B:
raise ValueError(f"batch_index {batch_index} out of range; moge_geometry has batch size {B}.")
# Pass depth so the rtol edge check sees radial depth -- for panoramas
# points[..., 2] = cos(phi)*r goes negative below the equator and the rtol clamp would drop the bottom half.
edge_depth = moge_geometry["depth"][batch_index] if "depth" in moge_geometry else None
verts, faces, uvs = triangulate_grid_mesh(
points[batch_index], decimation=decimation,
discontinuity_threshold=discontinuity_threshold, depth=edge_depth,
)
if verts.shape[0] == 0 or faces.shape[0] == 0:
raise ValueError("MoGe produced an empty mesh; try discontinuity_threshold=0 or apply_mask=False.")
if "intrinsics" not in moge_geometry:
# Panorama: rotate MoGe spherical (Z up) -> glTF (Y up, Z back), correct for inside-the-sphere viewing)
verts = verts[:, [1, 2, 0]].contiguous()
else:
# Perspective MoGe (X right, Y down, Z forward) -> glTF; face flip keeps winding CCW after the Y/Z flip.
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
faces = faces[:, [0, 2, 1]].contiguous()
tex = moge_geometry["image"][batch_index:batch_index + 1] if texture else None
mesh = Types.MESH(
vertices=verts.unsqueeze(0),
faces=faces.unsqueeze(0),
uvs=uvs.unsqueeze(0),
texture=tex,
)
return io.NodeOutput(mesh)
class MoGeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh]
async def comfy_entrypoint() -> MoGeExtension:
return MoGeExtension()

View File

@ -116,7 +116,7 @@ class EmptyQwenImageLayeredLatentImage(io.ComfyNode):
inputs=[
io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1, advanced=True),
io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[

View File

@ -0,0 +1,396 @@
"""Save-side 3D nodes: mesh packing/slicing helpers + GLB writer + SaveGLB node."""
import json
import logging
import os
import struct
from io import BytesIO
import numpy as np
from PIL import Image
import torch
from typing_extensions import override
import folder_paths
from comfy.cli_args import args
from comfy_api.latest import ComfyExtension, IO, Types
def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None):
# Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors,
# stashing per-item lengths as runtime attrs so consumers can recover the real slice.
# colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts.
# texture is (B, H, W, 3) — passed through unchanged
batch_size = len(vertices)
max_vertices = max(v.shape[0] for v in vertices)
max_faces = max(f.shape[0] for f in faces)
packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1]))
packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1]))
vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64)
face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64)
for i, (v, f) in enumerate(zip(vertices, faces)):
packed_vertices[i, :v.shape[0]] = v
packed_faces[i, :f.shape[0]] = f
packed_colors = None
if colors is not None:
packed_colors = colors[0].new_zeros((batch_size, max_vertices, colors[0].shape[1]))
for i, c in enumerate(colors):
assert c.shape[0] == vertices[i].shape[0], (
f"vertex_colors[{i}] has {c.shape[0]} entries, expected {vertices[i].shape[0]} (1:1 with vertices)"
)
packed_colors[i, :c.shape[0]] = c
packed_uvs = None
if uvs is not None:
packed_uvs = uvs[0].new_zeros((batch_size, max_vertices, uvs[0].shape[1]))
for i, u in enumerate(uvs):
assert u.shape[0] == vertices[i].shape[0], (
f"uvs[{i}] has {u.shape[0]} entries, expected {vertices[i].shape[0]} (1:1 with vertices)"
)
packed_uvs[i, :u.shape[0]] = u
return Types.MESH(packed_vertices, packed_faces,
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
vertex_counts=vertex_counts, face_counts=face_counts)
def get_mesh_batch_item(mesh, index):
# Returns (vertices, faces, colors, uvs) for batch index, slicing to real lengths
# if the mesh carries per-item counts (variable-size batch).
v_colors = getattr(mesh, "vertex_colors", None)
v_uvs = getattr(mesh, "uvs", None)
if getattr(mesh, "vertex_counts", None) is not None:
vertex_count = int(mesh.vertex_counts[index].item())
face_count = int(mesh.face_counts[index].item())
vertices = mesh.vertices[index, :vertex_count]
faces = mesh.faces[index, :face_count]
colors = v_colors[index, :vertex_count] if v_colors is not None else None
uvs = v_uvs[index, :vertex_count] if v_uvs is not None else None
return vertices, faces, colors, uvs
colors = v_colors[index] if v_colors is not None else None
uvs = v_uvs[index] if v_uvs is not None else None
return mesh.vertices[index], mesh.faces[index], colors, uvs
def save_glb(vertices, faces, filepath, metadata=None,
uvs=None, vertex_colors=None, texture_image=None):
"""
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
Parameters:
vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
filepath: str - Output filepath (should end with .glb)
metadata: dict - Optional asset.extras metadata
uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates
vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1]
texture_image: PIL.Image - Optional baseColor texture, embedded as PNG
"""
# Convert tensors to numpy arrays
vertices_np = vertices.cpu().numpy().astype(np.float32)
faces_signed = faces.cpu().numpy().astype(np.int64)
uvs_np = uvs.cpu().numpy().astype(np.float32) if uvs is not None else None
colors_np = vertex_colors.cpu().numpy().astype(np.float32) if vertex_colors is not None else None
if colors_np is not None:
colors_np = np.clip(colors_np, 0.0, 1.0)
n_verts = vertices_np.shape[0]
if n_verts == 0:
raise ValueError("save_glb: vertices is empty")
if faces_signed.size > 0:
fmin = int(faces_signed.min())
fmax = int(faces_signed.max())
if fmin < 0 or fmax >= n_verts:
raise ValueError(
f"save_glb: face index out of range [0, {n_verts}): min={fmin}, max={fmax}"
)
if uvs_np is not None and uvs_np.shape[0] != n_verts:
raise ValueError(
f"save_glb: uvs has {uvs_np.shape[0]} entries but vertex count is {n_verts}"
)
if colors_np is not None and colors_np.shape[0] != n_verts:
raise ValueError(
f"save_glb: vertex_colors has {colors_np.shape[0]} entries but vertex count is {n_verts}"
)
faces_np = faces_signed.astype(np.uint32)
texture_png_bytes = None
if texture_image is not None:
buf = BytesIO()
texture_image.save(buf, format="PNG")
texture_png_bytes = buf.getvalue()
vertices_buffer = vertices_np.tobytes()
indices_buffer = faces_np.tobytes()
uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b""
colors_buffer = colors_np.tobytes() if colors_np is not None else b""
texture_buffer = texture_png_bytes if texture_png_bytes is not None else b""
def pad_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b'\x00' * padding_length
vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
indices_buffer_padded = pad_to_4_bytes(indices_buffer)
uvs_buffer_padded = pad_to_4_bytes(uvs_buffer)
colors_buffer_padded = pad_to_4_bytes(colors_buffer)
texture_buffer_padded = pad_to_4_bytes(texture_buffer)
buffer_data = b"".join([
vertices_buffer_padded,
indices_buffer_padded,
uvs_buffer_padded,
colors_buffer_padded,
texture_buffer_padded,
])
vertices_byte_length = len(vertices_buffer)
vertices_byte_offset = 0
indices_byte_length = len(indices_buffer)
indices_byte_offset = len(vertices_buffer_padded)
uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded)
colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded)
texture_byte_offset = colors_byte_offset + len(colors_buffer_padded)
buffer_views = [
{
"buffer": 0,
"byteOffset": vertices_byte_offset,
"byteLength": vertices_byte_length,
"target": 34962 # ARRAY_BUFFER
},
{
"buffer": 0,
"byteOffset": indices_byte_offset,
"byteLength": indices_byte_length,
"target": 34963 # ELEMENT_ARRAY_BUFFER
}
]
accessors = [
{
"bufferView": 0,
"byteOffset": 0,
"componentType": 5126, # FLOAT
"count": len(vertices_np),
"type": "VEC3",
"max": vertices_np.max(axis=0).tolist(),
"min": vertices_np.min(axis=0).tolist()
},
{
"bufferView": 1,
"byteOffset": 0,
"componentType": 5125, # UNSIGNED_INT
"count": faces_np.size,
"type": "SCALAR"
}
]
primitive_attributes = {"POSITION": 0}
if uvs_np is not None and len(uvs_np) > 0:
buffer_views.append({
"buffer": 0,
"byteOffset": uvs_byte_offset,
"byteLength": len(uvs_buffer),
"target": 34962
})
accessor_idx = len(accessors)
accessors.append({
"bufferView": len(buffer_views) - 1,
"byteOffset": 0,
"componentType": 5126,
"count": len(uvs_np),
"type": "VEC2",
})
primitive_attributes["TEXCOORD_0"] = accessor_idx
if colors_np is not None and len(colors_np) > 0:
buffer_views.append({
"buffer": 0,
"byteOffset": colors_byte_offset,
"byteLength": len(colors_buffer),
"target": 34962
})
accessor_idx = len(accessors)
accessors.append({
"bufferView": len(buffer_views) - 1,
"byteOffset": 0,
"componentType": 5126,
"count": len(colors_np),
"type": "VEC3" if colors_np.shape[1] == 3 else "VEC4",
})
primitive_attributes["COLOR_0"] = accessor_idx
primitive = {
"attributes": primitive_attributes,
"indices": 1,
"mode": 4 # TRIANGLES
}
images = []
textures = []
samplers = []
materials = []
if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
buffer_views.append({
"buffer": 0,
"byteOffset": texture_byte_offset,
"byteLength": len(texture_buffer),
})
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
textures.append({"source": 0, "sampler": 0})
materials.append({
"pbrMetallicRoughness": {
"baseColorTexture": {"index": 0, "texCoord": 0},
"metallicFactor": 0.0,
"roughnessFactor": 1.0,
},
"doubleSided": True,
})
primitive["material"] = 0
gltf = {
"asset": {"version": "2.0", "generator": "ComfyUI"},
"buffers": [{"byteLength": len(buffer_data)}],
"bufferViews": buffer_views,
"accessors": accessors,
"meshes": [{"primitives": [primitive]}],
"nodes": [{"mesh": 0}],
"scenes": [{"nodes": [0]}],
"scene": 0,
}
if images:
gltf["images"] = images
if samplers:
gltf["samplers"] = samplers
if textures:
gltf["textures"] = textures
if materials:
gltf["materials"] = materials
if metadata:
gltf["asset"]["extras"] = metadata
# Convert the JSON to bytes
gltf_json = json.dumps(gltf).encode('utf8')
def pad_json_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
return buffer + b' ' * padding_length
gltf_json_padded = pad_json_to_4_bytes(gltf_json)
# Create the GLB header (a 4-byte ASCII magic identifier glTF)
glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
# Create JSON chunk header (chunk type 0)
json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A) # "JSON" in little endian
# Create BIN chunk header (chunk type 1)
bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942) # "BIN\0" in little endian
# Write the GLB file
with open(filepath, 'wb') as f:
f.write(glb_header)
f.write(json_chunk_header)
f.write(gltf_json_padded)
f.write(bin_chunk_header)
f.write(buffer_data)
return filepath
class SaveGLB(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveGLB",
display_name="Save 3D Model",
search_aliases=["export 3d model", "save mesh"],
category="3d",
essentials_category="Basics",
is_output_node=True,
inputs=[
IO.MultiType.Input(
IO.Mesh.Input("mesh"),
types=[
IO.File3DGLB,
IO.File3DGLTF,
IO.File3DOBJ,
IO.File3DFBX,
IO.File3DSTL,
IO.File3DUSDZ,
IO.File3DAny,
],
tooltip="Mesh or 3D file to save",
),
IO.String.Input("filename_prefix", default="3d/ComfyUI"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
)
@classmethod
def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = []
metadata = {}
if not args.disable_metadata:
if cls.hidden.prompt is not None:
metadata["prompt"] = json.dumps(cls.hidden.prompt)
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
if isinstance(mesh, Types.File3D):
# Handle File3D input - save BytesIO data to output folder
ext = mesh.format or "glb"
f = f"{filename}_{counter:05}_.{ext}"
mesh.save_to(os.path.join(full_output_folder, f))
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
counter += 1
else:
# Handle Mesh input - save vertices and faces as GLB; carry optional UVs / colors / texture.
texture_b = getattr(mesh, "texture", None)
texture_np = None
if texture_b is not None:
texture_np = (texture_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
assert texture_np.ndim == 4 and texture_np.shape[-1] == 3, (
f"texture must be (B, H, W, 3) RGB, got shape {tuple(texture_np.shape)}"
)
for i in range(mesh.vertices.shape[0]):
vertices_i, faces_i, v_colors, uvs_i = get_mesh_batch_item(mesh, i)
if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0:
logging.warning(f"SaveGLB: skipping empty mesh at batch index {i}")
continue
tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None
f = f"{filename}_{counter:05}_.glb"
save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata,
uvs=uvs_i,
vertex_colors=v_colors,
texture_image=tex_img)
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
counter += 1
return IO.NodeOutput(ui={"3d": results})
class Save3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [SaveGLB]
async def comfy_entrypoint() -> Save3DExtension:
return Save3DExtension()

View File

@ -123,6 +123,7 @@ class CreateVideo(io.ComfyNode):
search_aliases=["images to video"],
display_name="Create Video",
category="video",
essentials_category="Video Tools",
description="Create a video from images.",
inputs=[
io.Image.Input("images", tooltip="The images to create a video from."),

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.20.1"
__version__ = "0.21.1"

View File

@ -626,7 +626,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if comfy.model_management.is_oom(ex):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.info("Memory summary:\n{}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:

View File

@ -56,6 +56,8 @@ folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "backg
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geometry_estimation")], supported_pt_extensions)
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output")

View File

@ -700,17 +700,19 @@ class LoraLoader:
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
lora = None
lora_metadata = None
if self.loaded_lora is not None:
if self.loaded_lora[0] == lora_path:
lora = self.loaded_lora[1]
lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None
else:
self.loaded_lora = None
if lora is None:
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
self.loaded_lora = (lora_path, lora)
lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
self.loaded_lora = (lora_path, lora, lora_metadata)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata)
return (model_lora, clip_lora)
class LoraLoaderModelOnly(LoraLoader):
@ -1221,7 +1223,7 @@ class LatentFromBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 63}),
"batch_index": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION}),
"length": ("INT", {"default": 1, "min": 1, "max": 64}),
}}
RETURN_TYPES = ("LATENT",)
@ -1232,7 +1234,9 @@ class LatentFromBatch:
def frombatch(self, samples, batch_index, length):
s = samples.copy()
s_in = samples["samples"]
batch_index = min(s_in.shape[0] - 1, batch_index)
if batch_index < 0:
batch_index += s_in.shape[0]
batch_index = max(0, min(s_in.shape[0] - 1, batch_index))
length = min(s_in.shape[0] - batch_index, length)
s["samples"] = s_in[batch_index:batch_index + length].clone()
if "noise_mask" in samples:
@ -2435,6 +2439,9 @@ async def init_builtin_extra_nodes():
"nodes_sam3.py",
"nodes_void.py",
"nodes_wandancer.py",
"nodes_hidream_o1.py",
"nodes_save_3d.py",
"nodes_moge.py",
]
import_failed = []

View File

@ -2071,7 +2071,6 @@ paths:
type: integer
description: Number of assets marked as missing
# ===========================================================================
# Cloud-runtime FE-facing operations
#
@ -2122,7 +2121,11 @@ paths:
operationId: getCloudJobStatus
tags: [queue]
summary: Get status of a cloud job
description: "[cloud-only] Returns the current execution status of a cloud job."
deprecated: true
description: |
**Deprecated.** This endpoint is superseded by `GET /api/jobs/{job_id}`.
Clients should migrate; the endpoint is retained for backward
compatibility but will be removed in a future release.
x-runtime: [cloud]
parameters:
- name: job_id
@ -2192,7 +2195,11 @@ paths:
operationId: getHistoryV2
tags: [history]
summary: Get paginated execution history (v2)
description: "[cloud-only] Returns a paginated list of execution history entries in the v2 format, with richer metadata than the legacy history endpoint."
deprecated: true
description: |
**Deprecated.** This endpoint is superseded by `GET /api/jobs`.
Clients should migrate; the endpoint is retained for backward
compatibility but will be removed in a future release.
x-runtime: [cloud]
parameters:
- name: limit
@ -2231,7 +2238,11 @@ paths:
operationId: getHistoryV2ByPromptId
tags: [history]
summary: Get v2 history for a specific prompt
description: "[cloud-only] Returns the v2 history entry for a specific prompt execution."
deprecated: true
description: |
**Deprecated.** This endpoint is superseded by `GET /api/jobs/{prompt_id}`.
Clients should migrate; the endpoint is retained for backward
compatibility but will be removed in a future release.
x-runtime: [cloud]
parameters:
- name: prompt_id
@ -2266,7 +2277,12 @@ paths:
operationId: getCloudLogs
tags: [system]
summary: Get cloud execution logs
description: "[cloud-only] Returns execution logs for the authenticated user's cloud jobs."
deprecated: true
description: |
**Deprecated.** This endpoint returns a static placeholder response and
provides no real log data. It is retained only to avoid breaking clients
that still call it. Clients should remove their dependency; the endpoint
will be removed in a future release.
x-runtime: [cloud]
parameters:
- name: job_id
@ -5370,7 +5386,12 @@ paths:
operationId: viewVideo
tags: [view]
summary: View or download a video file
description: "[cloud-only] Serves a video file from the output directory. Used by the frontend video player."
deprecated: true
description: |
**Deprecated.** This endpoint is an alias of `GET /api/view` added for
legacy history-queue video playback. Callers should use `/api/view`
directly; the endpoint is retained for backward compatibility but will
be removed in a future release.
x-runtime: [cloud]
parameters:
- name: filename
@ -5523,7 +5544,6 @@ paths:
schema:
$ref: "#/components/schemas/CloudError"
components:
parameters:
ComfyUserHeader:
@ -6010,6 +6030,24 @@ components:
type: string
nullable: true
description: Minimum required workflow templates version for this ComfyUI build
comfy_package_versions:
type: array
description: Installed and required versions for every comfy* package pinned in requirements.txt
items:
type: object
required:
- name
- installed
- required
properties:
name:
type: string
installed:
type: string
nullable: true
required:
type: string
nullable: true
devices:
type: array
items:
@ -6875,7 +6913,6 @@ components:
error:
type: string
# -------------------------------------------------------------------
# Cloud-runtime schemas
#

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.20.1"
version = "0.21.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.43.18
comfyui-workflow-templates==0.9.72
comfyui-embedded-docs==0.4.4
comfyui-workflow-templates==0.9.77
comfyui-embedded-docs==0.5.0
torch
torchsde
torchvision

View File

@ -656,6 +656,7 @@ class PromptServer():
required_frontend_version = FrontendManager.get_required_frontend_version()
installed_templates_version = FrontendManager.get_installed_templates_version()
required_templates_version = FrontendManager.get_required_templates_version()
comfy_package_versions = FrontendManager.get_comfy_package_versions()
system_stats = {
"system": {
@ -666,6 +667,7 @@ class PromptServer():
"required_frontend_version": required_frontend_version,
"installed_templates_version": installed_templates_version,
"required_templates_version": required_templates_version,
"comfy_package_versions": comfy_package_versions,
"python_version": sys.version,
"pytorch_version": comfy.model_management.torch_version,
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",

View File

@ -52,7 +52,10 @@ def mock_provider(mock_releases):
@pytest.fixture(autouse=True)
def clear_cache():
import utils.install_util
import app.frontend_management
utils.install_util.PACKAGE_VERSIONS = {}
app.frontend_management.COMFY_PACKAGE_VERSIONS = []
def test_get_release(mock_provider, mock_releases):
@ -147,7 +150,7 @@ def test_init_frontend_default_with_mocks():
# Act
with (
patch("app.frontend_management.check_frontend_version") as mock_check,
patch("app.frontend_management.check_comfy_packages_versions") as mock_check,
patch.object(
FrontendManager, "default_frontend_path", return_value="/mocked/path"
),
@ -168,7 +171,7 @@ def test_init_frontend_fallback_on_error():
patch.object(
FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
),
patch("app.frontend_management.check_frontend_version") as mock_check,
patch("app.frontend_management.check_comfy_packages_versions") as mock_check,
patch.object(
FrontendManager, "default_frontend_path", return_value="/default/path"
),
@ -277,7 +280,9 @@ def test_get_installed_templates_version():
def test_get_installed_templates_version_not_installed():
# Act
with patch("app.frontend_management.version", side_effect=Exception("Package not found")):
with patch(
"app.frontend_management.version", side_effect=Exception("Package not found")
):
version = FrontendManager.get_installed_templates_version()
# Assert

View File

@ -124,9 +124,11 @@ class TestMathExpressionExecute:
with pytest.raises(Exception, match="not defined"):
self._exec("str(a)", a=42)
def test_boolean_result_raises(self):
with pytest.raises(ValueError, match="got bool"):
self._exec("a > b", a=5, b=3)
def test_boolean_result(self):
result = self._exec("a > b", a=5, b=3)
assert result[2] is True
result = self._exec("a > b", a=3, b=5)
assert result[2] is False
def test_empty_expression_raises(self):
with pytest.raises(ValueError, match="Expression cannot be empty"):

View File

@ -1,9 +1,23 @@
from collections import defaultdict
import torch
from comfy.model_detection import detect_unet_config, model_config_from_unet_config
import comfy.supported_models
def _freeze(value):
"""Recursively convert a value to a hashable form so configs can be
compared/used as dict keys or set members."""
if isinstance(value, dict):
return frozenset((k, _freeze(v)) for k, v in value.items())
if isinstance(value, (list, tuple)):
return tuple(_freeze(v) for v in value)
if isinstance(value, set):
return frozenset(_freeze(v) for v in value)
return value
def _make_longcat_comfyui_sd():
"""Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights."""
sd = {}
@ -110,3 +124,21 @@ class TestModelDetection:
model_config = model_config_from_unet_config(unet_config, sd)
assert model_config is not None
assert type(model_config).__name__ == "FluxSchnell"
def test_unet_config_and_required_keys_combination_is_unique(self):
"""Each model in the registry must have a unique combination of
``unet_config`` and ``required_keys``. If two models share the same
combination, ``BASE.matches`` cannot disambiguate between them and the
first one in the list will always win."""
models = comfy.supported_models.models
groups = defaultdict(list)
for model in models:
key = (_freeze(model.unet_config), _freeze(model.required_keys))
groups[key].append(model.__name__)
duplicates = {k: names for k, names in groups.items() if len(names) > 1}
assert not duplicates, (
"Found models sharing the same (unet_config, required_keys) "
"combination, which makes detection ambiguous: "
+ "; ".join(", ".join(names) for names in duplicates.values())
)