mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
Compare commits
34 Commits
a7ffbf5eaa
...
43abcfed7d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
43abcfed7d | ||
|
|
8b08bfdcbe | ||
|
|
4e823431cc | ||
|
|
66669b2ded | ||
|
|
65045730a6 | ||
|
|
87878f354f | ||
|
|
c5ecd231a2 | ||
|
|
9864f5ac86 | ||
|
|
05cd076bc1 | ||
|
|
d3c18c1636 | ||
|
|
bac6fc35fb | ||
|
|
56c74094c7 | ||
|
|
594de378fe | ||
|
|
c8673542f7 | ||
|
|
df7bf1d3dc | ||
|
|
ef8f25601a | ||
|
|
8dc3f3f209 | ||
|
|
c011fb520c | ||
|
|
c945a433ae | ||
|
|
25757a53c9 | ||
|
|
1b25f1289e | ||
|
|
e35348aa53 | ||
|
|
cd8c7a2306 | ||
|
|
6bcd8b96ab | ||
|
|
9c34f5f36a | ||
|
|
78b3096bf3 | ||
|
|
2b63add0ad | ||
|
|
160b95f75c | ||
|
|
c168960a12 | ||
|
|
e5369c0eec | ||
|
|
1655f8089a | ||
|
|
89014792c9 | ||
|
|
431fadb520 | ||
|
|
1ac60da2c9 |
31
.github/workflows/openapi-lint.yml
vendored
Normal file
31
.github/workflows/openapi-lint.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: OpenAPI Lint
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'openapi.yaml'
|
||||
- '.spectral.yaml'
|
||||
- '.github/workflows/openapi-lint.yml'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
spectral:
|
||||
name: Run Spectral
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install Spectral
|
||||
run: npm install -g @stoplight/spectral-cli@6
|
||||
|
||||
- name: Lint openapi.yaml
|
||||
run: spectral lint openapi.yaml --ruleset .spectral.yaml --fail-severity=error
|
||||
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@ -145,6 +145,8 @@ jobs:
|
||||
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
echo 'local-portable' > ComfyUI/.comfy_environment
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
|
||||
91
.spectral.yaml
Normal file
91
.spectral.yaml
Normal file
@ -0,0 +1,91 @@
|
||||
extends:
|
||||
- spectral:oas
|
||||
|
||||
# Severity levels: error, warn, info, hint, off
|
||||
# Rules from the built-in "spectral:oas" ruleset are active by default.
|
||||
# Below we tune severity and add custom rules for our conventions.
|
||||
#
|
||||
# This ruleset mirrors Comfy-Org/cloud/.spectral.yaml so specs across the
|
||||
# organization are linted against a single consistent standard.
|
||||
|
||||
rules:
|
||||
# -----------------------------------------------------------------------
|
||||
# Built-in rule severity overrides
|
||||
# -----------------------------------------------------------------------
|
||||
operation-operationId: error
|
||||
operation-description: warn
|
||||
operation-tag-defined: error
|
||||
info-contact: off
|
||||
info-description: warn
|
||||
no-eval-in-markdown: error
|
||||
no-$ref-siblings: error
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: naming conventions
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Property names should be snake_case
|
||||
property-name-snake-case:
|
||||
description: Property names must be snake_case
|
||||
severity: warn
|
||||
given: "$.components.schemas.*.properties[*]~"
|
||||
then:
|
||||
function: pattern
|
||||
functionOptions:
|
||||
match: "^[a-z][a-z0-9]*(_[a-z0-9]+)*$"
|
||||
|
||||
# Operation IDs should be camelCase
|
||||
operation-id-camel-case:
|
||||
description: Operation IDs must be camelCase
|
||||
severity: warn
|
||||
given: "$.paths.*.*.operationId"
|
||||
then:
|
||||
function: pattern
|
||||
functionOptions:
|
||||
match: "^[a-z][a-zA-Z0-9]*$"
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: response conventions
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Error responses (4xx, 5xx) should use a consistent shape
|
||||
error-response-schema:
|
||||
description: Error responses should reference a standard error schema
|
||||
severity: hint
|
||||
given: "$.paths.*.*.responses[?(@property >= '400' && @property < '600')].content['application/json'].schema"
|
||||
then:
|
||||
field: "$ref"
|
||||
function: truthy
|
||||
|
||||
# All 2xx responses with JSON body should have a schema
|
||||
response-schema-defined:
|
||||
description: Success responses with JSON content should define a schema
|
||||
severity: warn
|
||||
given: "$.paths.*.*.responses[?(@property >= '200' && @property < '300')].content['application/json']"
|
||||
then:
|
||||
field: schema
|
||||
function: truthy
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: best practices
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Path parameters must have a description
|
||||
path-param-description:
|
||||
description: Path parameters should have a description
|
||||
severity: warn
|
||||
given:
|
||||
- "$.paths.*.parameters[?(@.in == 'path')]"
|
||||
- "$.paths.*.*.parameters[?(@.in == 'path')]"
|
||||
then:
|
||||
field: description
|
||||
function: truthy
|
||||
|
||||
# Schemas should have a description
|
||||
schema-description:
|
||||
description: Component schemas should have a description
|
||||
severity: hint
|
||||
given: "$.components.schemas.*"
|
||||
then:
|
||||
field: description
|
||||
function: truthy
|
||||
@ -27,7 +27,7 @@ def frontend_install_warning_message():
|
||||
return f"""
|
||||
{get_missing_requirements_message()}
|
||||
|
||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||
The ComfyUI frontend is shipped in a pip package so it needs to be updated separately from the ComfyUI code.
|
||||
""".strip()
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
@ -31,8 +33,22 @@ class NodeReplaceManager:
|
||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register(self, node_replace: NodeReplace):
|
||||
"""Register a node replacement mapping."""
|
||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
"""Register a node replacement mapping.
|
||||
|
||||
Idempotent: if a replacement with the same (old_node_id, new_node_id)
|
||||
is already registered, the duplicate is ignored. This prevents stale
|
||||
entries from accumulating when custom nodes are reloaded in the same
|
||||
process (e.g. via ComfyUI-Manager).
|
||||
"""
|
||||
existing = self._replacements.setdefault(node_replace.old_node_id, [])
|
||||
for entry in existing:
|
||||
if entry.new_node_id == node_replace.new_node_id:
|
||||
logging.debug(
|
||||
"Node replacement %s -> %s already registered, ignoring duplicate.",
|
||||
node_replace.old_node_id, node_replace.new_node_id,
|
||||
)
|
||||
return
|
||||
existing.append(node_replace)
|
||||
|
||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||
"""Get replacements for an old node ID."""
|
||||
|
||||
@ -28,8 +28,8 @@ def get_file_info(path: str, relative_to: str) -> FileInfo:
|
||||
return {
|
||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
||||
"size": os.path.getsize(path),
|
||||
"modified": os.path.getmtime(path),
|
||||
"created": os.path.getctime(path)
|
||||
"modified": int(os.path.getmtime(path) * 1000),
|
||||
"created": int(os.path.getctime(path) * 1000),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -431,9 +431,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts image brightness and contrast using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -162,7 +162,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Canny to Image (Z-Image-Turbo)",
|
||||
"name": "Canny to Image (Z-Image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1553,7 +1553,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Canny to image"
|
||||
"category": "Image generation and editing/Canny to image",
|
||||
"description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1574,4 +1575,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -192,7 +192,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Canny to Video (LTX 2.0)",
|
||||
"name": "Canny to Video (LTX 2.0)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -3600,7 +3600,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Canny to video"
|
||||
"category": "Video generation and editing/Canny to video",
|
||||
"description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -3616,4 +3617,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -377,8 +377,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adds lens-style chromatic aberration (color fringing) using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -596,7 +596,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts saturation, temperature, tint, and vibrance using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -1129,7 +1129,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Balances colors across shadows, midtones, and highlights using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -608,7 +608,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Fine-tunes tone and color with per-channel curve adjustments using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -1609,7 +1609,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image Tools/Crop"
|
||||
"category": "Image Tools/Crop",
|
||||
"description": "Splits an image into a 2×2 grid of four equal tiles."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -2946,7 +2946,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image Tools/Crop"
|
||||
"category": "Image Tools/Crop",
|
||||
"description": "Splits an image into a 3×3 grid of nine equal tiles."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1579,7 +1579,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Depth to image"
|
||||
"category": "Image generation and editing/Depth to image",
|
||||
"description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning."
|
||||
},
|
||||
{
|
||||
"id": "458bdf3c-4b58-421c-af50-c9c663a4d74c",
|
||||
@ -2461,7 +2462,8 @@
|
||||
]
|
||||
},
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -4233,7 +4233,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Depth to video"
|
||||
"category": "Video generation and editing/Depth to video",
|
||||
"description": "Generates video from depth maps using LTX-2, with optional synchronized audio."
|
||||
},
|
||||
{
|
||||
"id": "38b60539-50a7-42f9-a5fe-bdeca26272e2",
|
||||
@ -5192,7 +5193,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -450,9 +450,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Blur"
|
||||
"category": "Image Tools/Blur",
|
||||
"description": "Applies bilateral (edge-preserving) blur to soften images while retaining detail."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -580,8 +580,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adds procedural film grain texture for a cinematic look via GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3350,7 +3350,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Video generation and editing/First-Last-Frame to Video"
|
||||
"category": "Video generation and editing/First-Last-Frame to Video",
|
||||
"description": "Generates a video interpolating between first and last keyframes using LTX-2.3."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -575,8 +575,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adds a glow/bloom effect around bright image areas via GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -752,8 +752,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts hue, saturation, and lightness of an image using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -374,7 +374,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Blur"
|
||||
"category": "Image Tools/Blur",
|
||||
"description": "Applies Gaussian, Box, or Radial blur to soften images and create stylized depth or motion effects."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -310,7 +310,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Image Captioning"
|
||||
"category": "Text generation/Image Captioning",
|
||||
"description": "Generates descriptive captions for images using Google's Gemini multimodal LLM."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -315,8 +315,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Manipulates individual RGBA channels for masking, compositing, and channel effects."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2138,7 +2138,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits images via text instructions using FireRed Image Edit 1.1, a diffusion-based instruction-following editing model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1472,7 +1472,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits an input image via text instructions using FLUX.2 [klein] 4B."
|
||||
},
|
||||
{
|
||||
"id": "6007e698-2ebd-4917-84d8-299b35d7b7ab",
|
||||
@ -1821,7 +1822,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Applies reference image conditioning for style/identity transfer (Flux.2 Klein 4B)."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1837,4 +1839,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
|
||||
@ -1417,7 +1417,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits images via text instructions using LongCat Image Edit, an instruction-following image editing diffusion model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -132,7 +132,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image Edit (Qwen 2511)",
|
||||
"name": "Image Edit (Qwen 2511)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1468,7 +1468,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits images via text instructions using Qwen-Image-Edit-2511 with improved character consistency and integrated LoRA."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1489,4 +1490,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -1188,7 +1188,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Inpaint image"
|
||||
"category": "Image generation and editing/Inpaint image",
|
||||
"description": "Inpaints masked image regions using Flux.1 fill [dev], Black Forest Labs' inpainting/outpainting model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1202,4 +1203,4 @@
|
||||
},
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1548,7 +1548,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Inpaint image"
|
||||
"category": "Image generation and editing/Inpaint image",
|
||||
"description": "Inpaints masked regions using Qwen-Image, extending its multilingual text rendering to inpainting tasks."
|
||||
},
|
||||
{
|
||||
"id": "56a1f603-fbd2-40ed-94ef-c9ecbd96aca8",
|
||||
@ -1907,7 +1908,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Expands and softens mask edges to reduce visible seams after image processing."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -742,9 +742,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts black point, white point, and gamma for tonal range control via GPU shader."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -1919,7 +1919,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Outpaint image"
|
||||
"category": "Image generation and editing/Outpaint image",
|
||||
"description": "Outpaints beyond image boundaries using Qwen-Image's outpainting capabilities."
|
||||
},
|
||||
{
|
||||
"id": "f93c215e-c393-460e-9534-ed2c3d8a652e",
|
||||
@ -2278,7 +2279,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Expands and softens mask edges to reduce visible seams after image processing."
|
||||
},
|
||||
{
|
||||
"id": "2a4b2cc0-db37-4302-a067-da392f38f06b",
|
||||
@ -2733,7 +2735,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Scales both image and mask together while preserving alignment for editing workflows."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -141,7 +141,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image Upscale(Z-image-Turbo)",
|
||||
"name": "Image Upscale (Z-image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1302,7 +1302,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Enhance"
|
||||
"category": "Image generation and editing/Enhance",
|
||||
"description": "Upscales images to higher resolution using Z-Image-Turbo."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -99,7 +99,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image to Depth Map (Lotus)",
|
||||
"name": "Image to Depth Map (Lotus)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -948,7 +948,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Depth to image"
|
||||
"category": "Image generation and editing/Depth to image",
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -964,4 +965,4 @@
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -1586,7 +1586,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Image to layers"
|
||||
"category": "Image generation and editing/Image to layers",
|
||||
"description": "Decomposes an image into variable-resolution RGBA layers for independent editing using Qwen-Image-Layered."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -72,7 +72,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image to Model (Hunyuan3d 2.1)",
|
||||
"name": "Image to 3D Model (Hunyuan3d 2.1)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -765,7 +765,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "3D/Image to 3D Model"
|
||||
"category": "3D/Image to 3D Model",
|
||||
"description": "Generates 3D mesh models from a single input image using Hunyuan3D 2.0/2.1."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -4223,7 +4223,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "Vue-corrected"
|
||||
},
|
||||
"category": "Video generation and editing/Image to video"
|
||||
"category": "Video generation and editing/Image to video",
|
||||
"description": "Generates video from a single input image using LTX-2.3."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -206,7 +206,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image to Video (Wan 2.2)",
|
||||
"name": "Image to Video (Wan 2.2)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -2027,7 +2027,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Image to video"
|
||||
"category": "Video generation and editing/Image to video",
|
||||
"description": "Generates video from an image and text prompt using Wan 2.2, supporting T2V and I2V."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -134,7 +134,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Pose to Image (Z-Image-Turbo)",
|
||||
"name": "Pose to Image (Z-Image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1298,7 +1298,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Pose to image"
|
||||
"category": "Image generation and editing/Pose to image",
|
||||
"description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1319,4 +1320,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -3870,7 +3870,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Pose to video"
|
||||
"category": "Video generation and editing/Pose to video",
|
||||
"description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -270,9 +270,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Prompt enhance"
|
||||
"category": "Text generation/Prompt enhance",
|
||||
"description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -302,8 +302,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Sharpen"
|
||||
"category": "Image Tools/Sharpen",
|
||||
"description": "Sharpens image details using a GPU fragment shader for enhanced clarity."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -222,7 +222,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Text to Audio (ACE-Step 1.5)",
|
||||
"name": "Text to Audio (ACE-Step 1.5)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1502,7 +1502,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Audio/Music generation"
|
||||
"category": "Audio/Music generation",
|
||||
"description": "Generates audio/music from text prompts using ACE-Step 1.5, a diffusion-based audio generation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1518,4 +1519,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -1029,7 +1029,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Flux.1 [dev], Black Forest Labs' 12B diffusion model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1043,4 +1044,4 @@
|
||||
},
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1023,7 +1023,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Flux.1 Krea Dev, a Black Forest Labs × Krea collaboration variant."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1037,4 +1038,4 @@
|
||||
},
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1104,7 +1104,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using NetaYume Lumina, fine-tuned from Neta Lumina for anime-style and illustration generation."
|
||||
},
|
||||
{
|
||||
"id": "a07fdf06-1bda-4dac-bdbd-63ee8ebca1c9",
|
||||
@ -1458,11 +1459,12 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Encodes a negative text prompt via CLIP for classifier-free guidance in anime-style generation (NetaYume Lumina)."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1941,7 +1941,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "Vue-corrected"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Qwen-Image-2512, with enhanced human realism and finer natural detail over the base version."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1873,7 +1873,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Qwen-Image, Alibaba's 20B MMDiT model with excellent multilingual text rendering."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -149,7 +149,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Text to Image (Z-Image-Turbo)",
|
||||
"name": "Text to Image (Z-Image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1054,7 +1054,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Z-Image-Turbo, Alibaba's distilled 6B DiT model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1075,4 +1076,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -4286,7 +4286,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "Vue-corrected"
|
||||
},
|
||||
"category": "Video generation and editing/Text to video"
|
||||
"category": "Video generation and editing/Text to video",
|
||||
"description": "Generates video from text prompts using LTX-2.3, Lightricks' video diffusion model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1572,7 +1572,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Text to video"
|
||||
"category": "Video generation and editing/Text to video",
|
||||
"description": "Generates video from text prompts using Wan2.2, Alibaba's diffusion video model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1586,4 +1587,4 @@
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -434,8 +434,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Sharpen"
|
||||
"category": "Image Tools/Sharpen",
|
||||
"description": "Enhances edge contrast via unsharp masking for a sharper image appearance."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -307,7 +307,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Video Captioning"
|
||||
"category": "Text generation/Video Captioning",
|
||||
"description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -165,7 +165,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Video Inpaint(Wan2.1 VACE)",
|
||||
"name": "Video Inpaint (Wan 2.1 VACE)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -2368,7 +2368,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Inpaint video"
|
||||
"category": "Video generation and editing/Inpaint video",
|
||||
"description": "Inpaints masked regions in video frames using Wan 2.1 VACE."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -584,8 +584,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video Tools/Stitch videos"
|
||||
"category": "Video Tools/Stitch videos",
|
||||
"description": "Stitches multiple video clips into a single sequential video file."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -412,9 +412,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Enhance video"
|
||||
"category": "Video generation and editing/Enhance video",
|
||||
"description": "Upscales video to 4× resolution using a GAN-based upscaling model."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
7
comfy/background_removal/birefnet.json
Normal file
7
comfy/background_removal/birefnet.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"model_type": "birefnet",
|
||||
"image_std": [1.0, 1.0, 1.0],
|
||||
"image_mean": [0.0, 0.0, 0.0],
|
||||
"image_size": 1024,
|
||||
"resize_to_original": true
|
||||
}
|
||||
689
comfy/background_removal/birefnet.py
Normal file
689
comfy/background_removal/birefnet.py
Normal file
@ -0,0 +1,689 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
import torch.nn.functional as F
|
||||
from torchvision.ops import deform_conv2d
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
k, v = kv[0], kv[1]
|
||||
|
||||
x = optimized_attention(
|
||||
q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
|
||||
).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
|
||||
self.act = nn.GELU()
|
||||
self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
|
||||
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
||||
mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||
norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
self.norm1 = norm_layer(dim, device=device, dtype=dtype)
|
||||
self.attn = WindowAttention(
|
||||
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.norm2 = norm_layer(dim, device=device, dtype=dtype)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.H = None
|
||||
self.W = None
|
||||
|
||||
def forward(self, x, mask_matrix):
|
||||
B, L, C = x.shape
|
||||
H, W = self.H, self.W
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
pad_l = pad_t = 0
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
_, Hp, Wp, _ = x.shape
|
||||
|
||||
if self.shift_size > 0:
|
||||
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
attn_mask = mask_matrix
|
||||
else:
|
||||
shifted_x = x
|
||||
attn_mask = None
|
||||
|
||||
x_windows = window_partition(shifted_x, self.window_size)
|
||||
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
|
||||
|
||||
attn_windows = self.attn(x_windows, mask=attn_mask)
|
||||
|
||||
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
||||
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
def __init__(self, dim, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, L, C = x.shape
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# padding
|
||||
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
||||
if pad_input:
|
||||
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
norm_layer=nn.LayerNorm,
|
||||
downsample=None,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.shift_size = window_size // 2
|
||||
self.depth = depth
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
norm_layer=norm_layer,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, H, W):
|
||||
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
||||
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
||||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(img_mask, self.window_size)
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
|
||||
for blk in self.blocks:
|
||||
blk.H, blk.W = H, W
|
||||
x = blk(x, attn_mask)
|
||||
if self.downsample is not None:
|
||||
x_down = self.downsample(x, H, W)
|
||||
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
||||
return x, H, W, x_down, Wh, Ww
|
||||
else:
|
||||
return x, H, W, x, H, W
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
patch_size = (patch_size, patch_size)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.size()
|
||||
if W % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
||||
if H % self.patch_size[0] != 0:
|
||||
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||
|
||||
x = self.proj(x) # B C Wh Ww
|
||||
if self.norm is not None:
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
pretrain_img_size=224,
|
||||
patch_size=4,
|
||||
in_channels=3,
|
||||
embed_dim=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
|
||||
self.pretrain_img_size = pretrain_img_size
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.patch_norm = patch_norm
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = BasicLayer(
|
||||
dim=int(embed_dim * 2 ** i_layer),
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.layers.append(layer)
|
||||
|
||||
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
self.num_features = num_features
|
||||
|
||||
for i_layer in out_indices:
|
||||
layer = norm_layer(num_features[i_layer])
|
||||
layer_name = f'norm{i_layer}'
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
|
||||
outs = []
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
x_out = norm_layer(x_out)
|
||||
|
||||
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
class DeformableConv2d(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False, device=None, dtype=None, operations=None):
|
||||
|
||||
super(DeformableConv2d, self).__init__()
|
||||
|
||||
kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
|
||||
self.stride = stride if type(stride) is tuple else (stride, stride)
|
||||
self.padding = padding
|
||||
|
||||
self.offset_conv = operations.Conv2d(in_channels,
|
||||
2 * kernel_size[0] * kernel_size[1],
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.modulator_conv = operations.Conv2d(in_channels,
|
||||
1 * kernel_size[0] * kernel_size[1],
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.regular_conv = operations.Conv2d(in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
offset = self.offset_conv(x)
|
||||
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
||||
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
|
||||
|
||||
x = deform_conv2d(
|
||||
input=x,
|
||||
offset=offset,
|
||||
weight=weight,
|
||||
bias=None,
|
||||
padding=self.padding,
|
||||
mask=modulator,
|
||||
stride=self.stride,
|
||||
)
|
||||
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
|
||||
return x
|
||||
|
||||
class BasicDecBlk(nn.Module):
|
||||
def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
|
||||
super(BasicDecBlk, self).__init__()
|
||||
inter_channels = 64
|
||||
self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
|
||||
self.relu_in = nn.ReLU(inplace=True)
|
||||
self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
|
||||
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
|
||||
self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
|
||||
self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.bn_in(x)
|
||||
x = self.relu_in(x)
|
||||
x = self.dec_att(x)
|
||||
x = self.conv_out(x)
|
||||
x = self.bn_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class BasicLatBlk(nn.Module):
|
||||
def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
|
||||
super(BasicLatBlk, self).__init__()
|
||||
self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class _ASPPModuleDeformable(nn.Module):
|
||||
def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
|
||||
super(_ASPPModuleDeformable, self).__init__()
|
||||
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
|
||||
stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
|
||||
self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.atrous_conv(x)
|
||||
x = self.bn(x)
|
||||
|
||||
return self.relu(x)
|
||||
|
||||
|
||||
class ASPPDeformable(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
|
||||
super(ASPPDeformable, self).__init__()
|
||||
self.down_scale = 1
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
self.in_channelster = 256 // self.down_scale
|
||||
|
||||
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
|
||||
self.aspp_deforms = nn.ModuleList([
|
||||
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
|
||||
for conv_size in parallel_block_sizes
|
||||
])
|
||||
|
||||
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
|
||||
operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
|
||||
operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=True))
|
||||
self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
|
||||
self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.aspp1(x)
|
||||
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
|
||||
x5 = self.global_avg_pool(x)
|
||||
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
|
||||
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
return x
|
||||
|
||||
class BiRefNet(nn.Module):
|
||||
def __init__(self, config=None, dtype=None, device=None, operations=None):
|
||||
super(BiRefNet, self).__init__()
|
||||
self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
channels = [1536, 768, 384, 192]
|
||||
channels = [c * 2 for c in channels]
|
||||
self.cxt = channels[1:][::-1][-3:]
|
||||
self.squeeze_module = nn.Sequential(*[
|
||||
BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(1)
|
||||
])
|
||||
|
||||
self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward_enc(self, x):
|
||||
x1, x2, x3, x4 = self.bb(x)
|
||||
B, C, H, W = x.shape
|
||||
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
|
||||
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x4 = torch.cat(
|
||||
(
|
||||
*[
|
||||
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||
][-len(CXT):],
|
||||
x4
|
||||
),
|
||||
dim=1
|
||||
)
|
||||
return (x1, x2, x3, x4)
|
||||
|
||||
def forward_ori(self, x):
|
||||
(x1, x2, x3, x4) = self.forward_enc(x)
|
||||
x4 = self.squeeze_module(x4)
|
||||
features = [x, x1, x2, x3, x4]
|
||||
scaled_preds = self.decoder(features)
|
||||
return scaled_preds
|
||||
|
||||
def forward(self, pixel_values, intermediate_output=None):
|
||||
scaled_preds = self.forward_ori(pixel_values)
|
||||
return scaled_preds
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, channels, device, dtype, operations):
|
||||
super(Decoder, self).__init__()
|
||||
# factory kwargs
|
||||
fk = {"device":device, "dtype":dtype, "operations":operations}
|
||||
DecoderBlock = partial(BasicDecBlk, **fk)
|
||||
LateralBlock = partial(BasicLatBlk, **fk)
|
||||
DBlock = partial(SimpleConvs, **fk)
|
||||
|
||||
self.split = True
|
||||
N_dec_ipt = 64
|
||||
ic = 64
|
||||
ipt_cha_opt = 1
|
||||
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
|
||||
|
||||
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
|
||||
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
|
||||
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
|
||||
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
|
||||
|
||||
fk = {"device":device, "dtype":dtype}
|
||||
|
||||
self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
|
||||
|
||||
self.lateral_block4 = LateralBlock(channels[1], channels[1])
|
||||
self.lateral_block3 = LateralBlock(channels[2], channels[2])
|
||||
self.lateral_block2 = LateralBlock(channels[3], channels[3])
|
||||
|
||||
self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
|
||||
self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
|
||||
self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
|
||||
|
||||
_N = 16
|
||||
|
||||
self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||
self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||
self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||
|
||||
[setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
|
||||
[setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
|
||||
|
||||
def get_patches_batch(self, x, p):
|
||||
_size_h, _size_w = p.shape[2:]
|
||||
patches_batch = []
|
||||
for idx in range(x.shape[0]):
|
||||
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
|
||||
patches_x = []
|
||||
for column_x in columns_x:
|
||||
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
|
||||
patch_sample = torch.cat(patches_x, dim=1)
|
||||
patches_batch.append(patch_sample)
|
||||
return torch.cat(patches_batch, dim=0)
|
||||
|
||||
def forward(self, features):
|
||||
x, x1, x2, x3, x4 = features
|
||||
|
||||
patches_batch = self.get_patches_batch(x, x4) if self.split else x
|
||||
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p4 = self.decoder_block4(x4)
|
||||
p4_gdt = self.gdt_convs_4(p4)
|
||||
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
|
||||
p4 = p4 * gdt_attn_4
|
||||
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
|
||||
_p3 = _p4 + self.lateral_block4(x3)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
|
||||
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p3 = self.decoder_block3(_p3)
|
||||
|
||||
p3_gdt = self.gdt_convs_3(p3)
|
||||
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
|
||||
p3 = p3 * gdt_attn_3
|
||||
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
|
||||
_p2 = _p3 + self.lateral_block3(x2)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
|
||||
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p2 = self.decoder_block2(_p2)
|
||||
|
||||
p2_gdt = self.gdt_convs_2(p2)
|
||||
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
|
||||
p2 = p2 * gdt_attn_2
|
||||
|
||||
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
|
||||
_p1 = _p2 + self.lateral_block2(x1)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
||||
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
_p1 = self.decoder_block1(_p1)
|
||||
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
||||
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p1_out = self.conv_out1(_p1)
|
||||
return p1_out
|
||||
|
||||
|
||||
class SimpleConvs(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
|
||||
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv_out(self.conv1(x))
|
||||
78
comfy/bg_removal_model.py
Normal file
78
comfy/bg_removal_model.py
Normal file
@ -0,0 +1,78 @@
|
||||
from .utils import load_torch_file
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.clip_model
|
||||
import comfy.background_removal.birefnet
|
||||
|
||||
BG_REMOVAL_MODELS = {
|
||||
"birefnet": comfy.background_removal.birefnet.BiRefNet
|
||||
}
|
||||
|
||||
class BackgroundRemovalModel():
|
||||
def __init__(self, json_config):
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 1024)
|
||||
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
|
||||
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
|
||||
self.model_type = config.get("model_type", "birefnet")
|
||||
self.config = config.copy()
|
||||
model_class = BG_REMOVAL_MODELS.get(self.model_type)
|
||||
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_image(self, image):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
H, W = image.shape[1], image.shape[2]
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
|
||||
out = self.model(pixel_values=pixel_values)
|
||||
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
|
||||
|
||||
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1] != 1:
|
||||
mask = mask.movedim(-1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def load_background_removal_model(sd):
|
||||
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
bg_model = BackgroundRemovalModel(json_config)
|
||||
m, u = bg_model.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing background removal: {}".format(m))
|
||||
u = set(u)
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
if k not in u:
|
||||
sd.pop(k)
|
||||
return bg_model
|
||||
|
||||
def load(ckpt_path):
|
||||
sd = load_torch_file(ckpt_path)
|
||||
return load_background_removal_model(sd)
|
||||
@ -63,7 +63,11 @@ class IndexListContextWindow(ContextWindowABC):
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
indices = self.index_list
|
||||
anchor_idx = getattr(self, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
indices = [anchor_idx] + list(indices)
|
||||
idx = tuple([slice(None)] * dim + [indices])
|
||||
window = full[idx]
|
||||
if retain_index_list:
|
||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||
@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
||||
|
||||
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||
if temporal_offset > 0:
|
||||
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
||||
anchor_idx = getattr(window, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
# anchor occupies one of the no-cond positions, so skip one fewer from window.index_list
|
||||
skip_count = temporal_offset - 1
|
||||
else:
|
||||
skip_count = temporal_offset
|
||||
|
||||
indices = [i - temporal_offset for i in window.index_list[skip_count:]]
|
||||
indices = [i for i in indices if 0 <= i]
|
||||
else:
|
||||
indices = list(window.index_list)
|
||||
@ -150,7 +161,8 @@ class ContextFuseMethod:
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
||||
causal_window_fix: bool=True):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -162,6 +174,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
self.causal_window_fix = causal_window_fix
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
@ -318,6 +331,14 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
|
||||
anchor_applied = False
|
||||
if self.causal_window_fix:
|
||||
anchor_idx = window.index_list[0] - 1
|
||||
if 0 <= anchor_idx < x_in.size(self.dim):
|
||||
window.causal_anchor_index = anchor_idx
|
||||
anchor_applied = True
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
@ -332,6 +353,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
|
||||
# strip causal_window_fix anchor if applied
|
||||
if anchor_applied:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
|
||||
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
@ -93,7 +93,7 @@ class Hook:
|
||||
self.hook_scope = hook_scope
|
||||
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||
self.custom_should_register = default_should_register
|
||||
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||
'''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||
|
||||
@property
|
||||
def strength(self):
|
||||
|
||||
@ -1859,6 +1859,23 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
||||
output = torch.zeros_like(x)
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
current_start_frame = 0
|
||||
|
||||
# I2V: seed KV cache with the initial image latent before the denoising loop
|
||||
initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None)
|
||||
if initial_latent is not None:
|
||||
initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype)
|
||||
n_init = initial_latent.shape[2]
|
||||
output[:, :, :n_init] = initial_latent
|
||||
|
||||
ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches}
|
||||
transformer_options["ar_state"] = ar_state
|
||||
zero_sigma = sigmas.new_zeros([1])
|
||||
_ = model(initial_latent, zero_sigma * s_in, **extra_args)
|
||||
|
||||
current_start_frame = n_init
|
||||
remaining = lat_t - n_init
|
||||
num_blocks = -(-remaining // num_frame_per_block)
|
||||
|
||||
num_sigma_steps = len(sigmas) - 1
|
||||
total_real_steps = num_blocks * num_sigma_steps
|
||||
step_count = 0
|
||||
|
||||
@ -9,6 +9,7 @@ class LatentFormat:
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
spacial_downscale_ratio = 8
|
||||
temporal_downscale_ratio = 1
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent * self.scale_factor
|
||||
@ -235,6 +236,7 @@ class Flux2(LatentFormat):
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 6
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
@ -278,6 +280,7 @@ class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 32
|
||||
temporal_downscale_ratio = 8
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
@ -421,6 +424,7 @@ class LTXAV(LTXV):
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
scale_factor = 0.476986
|
||||
latent_rgb_factors = [
|
||||
[-0.0395, -0.0331, 0.0445],
|
||||
@ -447,6 +451,7 @@ class HunyuanVideo(LatentFormat):
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 8
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.1817, 0.2284, 0.2423],
|
||||
@ -472,6 +477,7 @@ class Cosmos1CV8x8x8(LatentFormat):
|
||||
class Wan21(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
|
||||
latent_rgb_factors = [
|
||||
[-0.1299, -0.1692, 0.2932],
|
||||
@ -734,6 +740,7 @@ class HunyuanVideo15(LatentFormat):
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
temporal_downscale_ratio = 4
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
@ -786,8 +793,27 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
pass
|
||||
|
||||
class CogVideoX(LatentFormat):
|
||||
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
|
||||
|
||||
scale_factor matches the vae/config.json scaling_factor for the 2b variant.
|
||||
The 5b-class checkpoints (CogVideoX-5b, CogVideoX-1.5-5B, CogVideoX-Fun-V1.5-*)
|
||||
use a different value; see CogVideoX1_5 below.
|
||||
"""
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.15258426
|
||||
|
||||
|
||||
class CogVideoX1_5(CogVideoX):
|
||||
"""Latent format for 5b-class CogVideoX checkpoints.
|
||||
|
||||
Covers THUDM/CogVideoX-5b, THUDM/CogVideoX-1.5-5B, and the CogVideoX-Fun
|
||||
V1.5-5b family (including VOID inpainting). All of these have
|
||||
scaling_factor=0.7 in their vae/config.json. Auto-selected in
|
||||
supported_models.CogVideoX_T2V based on transformer hidden dim.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.7
|
||||
|
||||
@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
# according to the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
if verbose:
|
||||
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
|
||||
@ -561,7 +561,8 @@ class SAM3Model(nn.Module):
|
||||
return high_res_masks
|
||||
|
||||
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1):
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
||||
target_device=None, target_dtype=None):
|
||||
"""Track video with optional per-frame text-prompted detection."""
|
||||
bb = self.detector.backbone["vision_backbone"]
|
||||
|
||||
@ -589,8 +590,10 @@ class SAM3Model(nn.Module):
|
||||
return self.tracker.track_video_with_detection(
|
||||
backbone_fn, images, initial_masks, detect_fn,
|
||||
new_det_thresh=new_det_thresh, max_objects=max_objects,
|
||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
|
||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
# SAM3 (non-multiplex) — no detection support, requires initial masks
|
||||
if initial_masks is None:
|
||||
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
|
||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
|
||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
|
||||
@ -200,8 +200,13 @@ def pack_masks(masks):
|
||||
|
||||
def unpack_masks(packed):
|
||||
"""Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
|
||||
shifts = torch.arange(8, device=packed.device)
|
||||
return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool()
|
||||
bits = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
|
||||
return (packed.unsqueeze(-1) & bits).bool().view(*packed.shape[:-1], -1)
|
||||
|
||||
|
||||
def _prep_frame(images, idx, device, dt, size):
|
||||
"""Slice CPU full-res frames, transfer to GPU in target dtype, and resize to (size, size)."""
|
||||
return comfy.utils.common_upscale(images[idx].to(device=device, dtype=dt), size, size, "bicubic", crop="disabled")
|
||||
|
||||
|
||||
def _compute_backbone(backbone_fn, frame, frame_idx=None):
|
||||
@ -1078,16 +1083,19 @@ class SAM3Tracker(nn.Module):
|
||||
# SAM3: drop last FPN level
|
||||
return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]
|
||||
|
||||
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None):
|
||||
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None,
|
||||
target_device=None, target_dtype=None):
|
||||
"""Track one object, computing backbone per frame to save VRAM."""
|
||||
N = images.shape[0]
|
||||
device, dt = images.device, images.dtype
|
||||
device = target_device if target_device is not None else images.device
|
||||
dt = target_dtype if target_dtype is not None else images.dtype
|
||||
size = self.image_size
|
||||
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
||||
all_masks = []
|
||||
|
||||
for frame_idx in tqdm(range(N), desc="tracking"):
|
||||
vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
|
||||
backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx)
|
||||
backbone_fn, _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), frame_idx=frame_idx)
|
||||
mask_input = None
|
||||
if frame_idx == 0:
|
||||
mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
|
||||
@ -1114,12 +1122,13 @@ class SAM3Tracker(nn.Module):
|
||||
|
||||
return torch.cat(all_masks, dim=0) # [N, 1, H, W]
|
||||
|
||||
def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs):
|
||||
def track_video(self, backbone_fn, images, initial_masks, pbar=None,
|
||||
target_device=None, target_dtype=None, **kwargs):
|
||||
"""Track one or more objects across video frames.
|
||||
|
||||
Args:
|
||||
backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
|
||||
images: [N, 3, 1008, 1008] video frames
|
||||
images: [N, 3, H, W] CPU full-res video frames (resized per-frame to self.image_size)
|
||||
initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
|
||||
pbar: optional progress bar
|
||||
|
||||
@ -1130,7 +1139,8 @@ class SAM3Tracker(nn.Module):
|
||||
per_object = []
|
||||
for obj_idx in range(N_obj):
|
||||
obj_masks = self._track_single_object(
|
||||
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar)
|
||||
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
per_object.append(obj_masks)
|
||||
|
||||
return torch.cat(per_object, dim=1) # [N, N_obj, H, W]
|
||||
@ -1632,11 +1642,18 @@ class SAM31Tracker(nn.Module):
|
||||
return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
|
||||
return []
|
||||
|
||||
INTERNAL_MAX_OBJECTS = 64 # Hard ceiling on accumulated tracks; max_objects=0 or any value above this is clamped here.
|
||||
|
||||
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
||||
backbone_obj=None, pbar=None):
|
||||
backbone_obj=None, pbar=None, target_device=None, target_dtype=None):
|
||||
"""Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits."""
|
||||
N, device, dt = images.shape[0], images.device, images.dtype
|
||||
if max_objects <= 0 or max_objects > self.INTERNAL_MAX_OBJECTS:
|
||||
max_objects = self.INTERNAL_MAX_OBJECTS
|
||||
N = images.shape[0]
|
||||
device = target_device if target_device is not None else images.device
|
||||
dt = target_dtype if target_dtype is not None else images.dtype
|
||||
size = self.image_size
|
||||
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
||||
all_masks = []
|
||||
idev = comfy.model_management.intermediate_device()
|
||||
@ -1656,7 +1673,7 @@ class SAM31Tracker(nn.Module):
|
||||
prefetch = True
|
||||
except RuntimeError:
|
||||
pass
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(0, 1), device, dt, size), frame_idx=0)
|
||||
|
||||
for frame_idx in tqdm(range(N), desc="tracking"):
|
||||
vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb
|
||||
@ -1666,7 +1683,7 @@ class SAM31Tracker(nn.Module):
|
||||
backbone_stream.wait_stream(torch.cuda.current_stream(device))
|
||||
with torch.cuda.stream(backbone_stream):
|
||||
next_bb = self._compute_backbone_frame(
|
||||
backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
|
||||
# Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
|
||||
det_masks = torch.empty(0, device=device)
|
||||
@ -1687,7 +1704,7 @@ class SAM31Tracker(nn.Module):
|
||||
current_out = self._condition_with_masks(
|
||||
initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
|
||||
feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
|
||||
images[frame_idx:frame_idx + 1], trunk_out)
|
||||
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out)
|
||||
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
||||
obj_scores = [1.0] * mux_state.total_valid_entries
|
||||
if keep_alive is not None:
|
||||
@ -1702,7 +1719,7 @@ class SAM31Tracker(nn.Module):
|
||||
current_out = self._condition_with_masks(
|
||||
det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
|
||||
output_dict, N, mux_state, backbone_obj,
|
||||
images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0)
|
||||
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out, threshold=0.0)
|
||||
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
||||
obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
|
||||
if keep_alive is not None:
|
||||
@ -1718,7 +1735,7 @@ class SAM31Tracker(nn.Module):
|
||||
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
||||
cur_bb = next_bb
|
||||
else:
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
continue
|
||||
else:
|
||||
N_obj = mux_state.total_valid_entries
|
||||
@ -1768,7 +1785,7 @@ class SAM31Tracker(nn.Module):
|
||||
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
||||
cur_bb = next_bb
|
||||
else:
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
|
||||
if not all_masks or all(m is None for m in all_masks):
|
||||
return {"packed_masks": None, "n_frames": N, "scores": []}
|
||||
|
||||
@ -26,6 +26,7 @@ import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
import comfy.float
|
||||
import comfy.hooks
|
||||
@ -1651,7 +1652,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
|
||||
|
||||
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
log_key = (self.patches_uuid, allocated_size, num_patches, len(self.backup), self.model.model_loaded_weight_memory)
|
||||
in_loop = bool(getattr(tqdm.tqdm, "_instances", None))
|
||||
level = logging.DEBUG if in_loop and getattr(self, "_last_prepare_log_key", None) == log_key else logging.INFO
|
||||
self._last_prepare_log_key = log_key
|
||||
logging.log(level, f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
|
||||
self.model.device = device_to
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
22
comfy/ops.py
22
comfy/ops.py
@ -562,6 +562,25 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
|
||||
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
|
||||
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
@ -749,6 +768,9 @@ class manual_cast(disable_weight_init):
|
||||
class Conv3d(disable_weight_init.Conv3d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class BatchNorm2d(disable_weight_init.BatchNorm2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class GroupNorm(disable_weight_init.GroupNorm):
|
||||
comfy_cast_weights = True
|
||||
|
||||
|
||||
@ -66,6 +66,7 @@ import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -1224,6 +1225,7 @@ class CLIPType(Enum):
|
||||
NEWBIE = 24
|
||||
FLUX2 = 25
|
||||
LONGCAT_IMAGE = 26
|
||||
COGVIDEOX = 27
|
||||
|
||||
|
||||
|
||||
@ -1428,6 +1430,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif clip_type == CLIPType.COGVIDEOX:
|
||||
clip_target.clip = comfy.text_encoders.cogvideo.cogvideo_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.cogvideo.CogVideoXTokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||
|
||||
@ -1872,6 +1872,14 @@ class CogVideoX_T2V(supported_models_base.BASE):
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
# 2b-class (dim=1920, heads=30) uses scale_factor=1.15258426.
|
||||
# 5b-class (dim=3072, heads=48) — incl. CogVideoX-5b, 1.5-5B, and
|
||||
# Fun-V1.5 inpainting — uses scale_factor=0.7 per vae/config.json.
|
||||
if unet_config.get("num_attention_heads", 0) >= 48:
|
||||
self.latent_format = latent_formats.CogVideoX1_5
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
# CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
@ -1898,6 +1906,20 @@ class CogVideoX_I2V(CogVideoX_T2V):
|
||||
out = model_base.CogVideoX(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class CogVideoX_Inpaint(CogVideoX_T2V):
|
||||
unet_config = {
|
||||
"image_model": "cogvideox",
|
||||
"in_channels": 48,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
self.unet_config.setdefault("sample_height", 96)
|
||||
self.unet_config.setdefault("sample_width", 170)
|
||||
self.unet_config.setdefault("sample_frames", 81)
|
||||
out = model_base.CogVideoX(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
|
||||
models = [
|
||||
LotusD,
|
||||
@ -1978,6 +2000,7 @@ models = [
|
||||
ErnieImage,
|
||||
SAM3,
|
||||
SAM31,
|
||||
CogVideoX_Inpaint,
|
||||
CogVideoX_I2V,
|
||||
CogVideoX_T2V,
|
||||
SVD_img2vid,
|
||||
|
||||
@ -1,6 +1,48 @@
|
||||
import comfy.text_encoders.sd3_clip
|
||||
from comfy import sd1_clip
|
||||
|
||||
|
||||
class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
|
||||
"""Inner T5 tokenizer for CogVideoX.
|
||||
|
||||
CogVideoX was trained with T5 embeddings padded to 226 tokens (not 77 like SD3).
|
||||
Used both directly by supported_models.CogVideoX_T2V.clip_target (paired with
|
||||
the raw T5XXLModel) and by the CogVideoXTokenizer outer wrapper below.
|
||||
"""
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)
|
||||
|
||||
|
||||
class CogVideoXTokenizer(sd1_clip.SD1Tokenizer):
|
||||
"""Outer tokenizer wrapper for CLIPLoader (type="cogvideox")."""
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
clip_name="t5xxl", tokenizer=CogVideoXT5Tokenizer)
|
||||
|
||||
|
||||
class CogVideoXT5XXL(sd1_clip.SD1ClipModel):
|
||||
"""Outer T5XXL model wrapper for CLIPLoader (type="cogvideox").
|
||||
|
||||
Wraps the raw T5XXL model in the SD1ClipModel interface so that CLIP.__init__
|
||||
(which reads self.dtypes) works correctly. The inner model is the standard
|
||||
sd3_clip.T5XXLModel (no attention_mask change needed for CogVideoX).
|
||||
"""
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl",
|
||||
clip_model=comfy.text_encoders.sd3_clip.T5XXLModel,
|
||||
model_options=model_options)
|
||||
|
||||
|
||||
def cogvideo_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
"""Factory that returns a CogVideoXT5XXL class configured with the detected
|
||||
T5 dtype and optional quantization metadata, for use in load_text_encoder_state_dicts.
|
||||
"""
|
||||
class CogVideoXTEModel_(CogVideoXT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return CogVideoXTEModel_
|
||||
|
||||
@ -1390,7 +1390,7 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
k_out = "{}.weight_scale".format(layer)
|
||||
|
||||
if layer is not None:
|
||||
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
|
||||
layer_conf = {"format": "float8_e4m3fn"}
|
||||
if full_precision_matrix_mult:
|
||||
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
||||
layers[layer] = layer_conf
|
||||
|
||||
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
||||
from spandrel import ImageModelDescriptor
|
||||
from comfy.clip_vision import ClipVisionModel
|
||||
from comfy.clip_vision import Output as ClipVisionOutput_
|
||||
from comfy.bg_removal_model import BackgroundRemovalModel
|
||||
from comfy.controlnet import ControlNet
|
||||
from comfy.hooks import HookGroup, HookKeyframeGroup
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
@ -536,7 +537,6 @@ class Combo(ComfyTypeIO):
|
||||
@comfytype(io_type="COMBO")
|
||||
class MultiCombo(ComfyTypeI):
|
||||
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||
# TODO: something is wrong with the serialization, frontend does not recognize it as multiselect
|
||||
Type = list[str]
|
||||
class Input(Combo.Input):
|
||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
@ -549,12 +549,14 @@ class MultiCombo(ComfyTypeI):
|
||||
self.default: list[str]
|
||||
|
||||
def as_dict(self):
|
||||
to_return = super().as_dict() | prune_dict({
|
||||
"multi_select": self.multiselect,
|
||||
"placeholder": self.placeholder,
|
||||
"chip": self.chip,
|
||||
# Frontend expects `multi_select` to be an object config (not a boolean).
|
||||
# Keep top-level `multiselect` from Combo.Input for backwards compatibility.
|
||||
return super().as_dict() | prune_dict({
|
||||
"multi_select": prune_dict({
|
||||
"placeholder": self.placeholder,
|
||||
"chip": self.chip,
|
||||
}),
|
||||
})
|
||||
return to_return
|
||||
|
||||
@comfytype(io_type="IMAGE")
|
||||
class Image(ComfyTypeIO):
|
||||
@ -754,6 +756,11 @@ class Model(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
Type = ModelPatcher
|
||||
|
||||
@comfytype(io_type="BACKGROUND_REMOVAL")
|
||||
class BackgroundRemoval(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
Type = BackgroundRemovalModel
|
||||
|
||||
@comfytype(io_type="CLIP_VISION")
|
||||
class ClipVision(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
@ -2399,6 +2406,7 @@ __all__ = [
|
||||
"ModelPatch",
|
||||
"ClipVision",
|
||||
"ClipVisionOutput",
|
||||
"BackgroundRemoval",
|
||||
"AudioEncoder",
|
||||
"AudioEncoderOutput",
|
||||
"StyleModel",
|
||||
|
||||
@ -1282,7 +1282,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
||||
)
|
||||
|
||||
|
||||
def _seedance2_text_inputs(resolutions: list[str]):
|
||||
def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
|
||||
return [
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -1298,6 +1298,7 @@ def _seedance2_text_inputs(resolutions: list[str]):
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
||||
default=default_ratio,
|
||||
tooltip="Aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -1431,8 +1432,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0",
|
||||
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
@ -1599,9 +1606,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
def _seedance2_reference_inputs(resolutions: list[str]):
|
||||
def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"):
|
||||
return [
|
||||
*_seedance2_text_inputs(resolutions),
|
||||
*_seedance2_text_inputs(resolutions, default_ratio=default_ratio),
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
@ -1679,8 +1686,14 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0",
|
||||
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
|
||||
@ -83,13 +83,16 @@ class GeminiImageModel(str, Enum):
|
||||
|
||||
async def create_image_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: Input.Image,
|
||||
images: Input.Image | list[Input.Image],
|
||||
image_limit: int = 0,
|
||||
) -> list[GeminiPart]:
|
||||
image_parts: list[GeminiPart] = []
|
||||
if image_limit < 0:
|
||||
raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
|
||||
total_images = get_number_of_images(images)
|
||||
|
||||
# Accept either a single (possibly-batched) tensor or a list of them; share URL budget across all.
|
||||
images_list: list[Input.Image] = images if isinstance(images, list) else [images]
|
||||
total_images = sum(get_number_of_images(img) for img in images_list)
|
||||
if total_images <= 0:
|
||||
raise ValueError("No images provided to create_image_parts; at least one image is required.")
|
||||
|
||||
@ -98,10 +101,18 @@ async def create_image_parts(
|
||||
|
||||
# Number of images we'll send as URLs (fileData)
|
||||
num_url_images = min(effective_max, 10) # Vertex API max number of image links
|
||||
upload_kwargs: dict = {"wait_label": "Uploading reference images"}
|
||||
if effective_max > num_url_images:
|
||||
# Split path (e.g. 11+ images): suppress per-image counter to avoid a confusing dual-fraction label.
|
||||
upload_kwargs = {
|
||||
"wait_label": f"Uploading reference images ({num_url_images}+)",
|
||||
"show_batch_index": False,
|
||||
}
|
||||
reference_images_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
images,
|
||||
images_list,
|
||||
max_images=num_url_images,
|
||||
**upload_kwargs,
|
||||
)
|
||||
for reference_image_url in reference_images_urls:
|
||||
image_parts.append(
|
||||
@ -112,15 +123,22 @@ async def create_image_parts(
|
||||
)
|
||||
)
|
||||
)
|
||||
for idx in range(num_url_images, effective_max):
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=tensor_to_base64_string(images[idx]),
|
||||
if effective_max > num_url_images:
|
||||
flat: list[torch.Tensor] = []
|
||||
for tensor in images_list:
|
||||
if len(tensor.shape) == 4:
|
||||
flat.extend(tensor[i] for i in range(tensor.shape[0]))
|
||||
else:
|
||||
flat.append(tensor)
|
||||
for idx in range(num_url_images, effective_max):
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=tensor_to_base64_string(flat[idx]),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
return image_parts
|
||||
|
||||
|
||||
@ -891,10 +909,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
"9:16",
|
||||
"16:9",
|
||||
"21:9",
|
||||
# "1:4",
|
||||
# "4:1",
|
||||
# "8:1",
|
||||
# "1:8",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||
@ -902,12 +916,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
# "512px",
|
||||
"1K",
|
||||
"2K",
|
||||
"4K",
|
||||
],
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
@ -956,6 +965,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1016,6 +1026,197 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
def _nano_banana_2_v2_model_inputs():
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[
|
||||
"auto",
|
||||
"1:1",
|
||||
"2:3",
|
||||
"3:2",
|
||||
"3:4",
|
||||
"4:3",
|
||||
"4:5",
|
||||
"5:4",
|
||||
"9:16",
|
||||
"16:9",
|
||||
"21:9",
|
||||
"1:4",
|
||||
"4:1",
|
||||
"8:1",
|
||||
"1:8",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||
"if no image is provided, a 16:9 square is usually generated.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"thinking_level",
|
||||
options=["MINIMAL", "HIGH"],
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, 15)],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional reference image(s). Up to 14 images total.",
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNanoBanana2V2",
|
||||
display_name="Nano Banana 2",
|
||||
category="api node/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text prompt describing the image to generate or the edits to apply. "
|
||||
"Include any constraints, styles, or details the model should follow.",
|
||||
default="",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Nano Banana 2 (Gemini 3.1 Flash Image)",
|
||||
_nano_banana_2_v2_model_inputs(),
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
|
||||
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||
"Also, changing the model or parameter settings, such as the temperature, "
|
||||
"can cause variations in the response even when you use the same seed value. "
|
||||
"By default, a random seed value is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"response_modalities",
|
||||
options=["IMAGE", "IMAGE+TEXT"],
|
||||
advanced=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
IO.Image.Output(
|
||||
display_name="thought_image",
|
||||
tooltip="First image from the model's thinking process. "
|
||||
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$r := $lookup(widgets, "model.resolution");
|
||||
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
response_modalities: str,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_choice = model["model"]
|
||||
if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
||||
model_id = "gemini-3.1-flash-image-preview"
|
||||
else:
|
||||
model_id = model_choice
|
||||
|
||||
images = model.get("images") or {}
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
if images:
|
||||
image_tensors: list[Input.Image] = [t for t in images.values() if t is not None]
|
||||
if image_tensors:
|
||||
if sum(get_number_of_images(t) for t in image_tensors) > 14:
|
||||
raise ValueError("The current maximum number of supported images is 14.")
|
||||
parts.extend(await create_image_parts(cls, image_tensors))
|
||||
files = model.get("files")
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
image_config = GeminiImageConfig(imageSize=model["resolution"])
|
||||
if model["aspect_ratio"] != "auto":
|
||||
image_config.aspectRatio = model["aspect_ratio"]
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model_id}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await get_image_from_response(response),
|
||||
get_text_from_response(response),
|
||||
await get_image_from_response(response, thought=True),
|
||||
)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1024,6 +1225,7 @@ class GeminiExtension(ComfyExtension):
|
||||
GeminiImage,
|
||||
GeminiImage2,
|
||||
GeminiNanoBanana2,
|
||||
GeminiNanoBanana2V2,
|
||||
GeminiInputFiles,
|
||||
]
|
||||
|
||||
|
||||
@ -54,7 +54,12 @@ class GrokImageNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
|
||||
options=[
|
||||
"grok-imagine-image-quality",
|
||||
"grok-imagine-image-pro",
|
||||
"grok-imagine-image",
|
||||
"grok-imagine-image-beta",
|
||||
],
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -111,10 +116,12 @@ class GrokImageNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
|
||||
$rate := widgets.model = "grok-imagine-image-quality"
|
||||
? (widgets.resolution = "1k" ? 0.05 : 0.07)
|
||||
: ($contains(widgets.model, "pro") ? 0.07 : 0.02);
|
||||
{"type":"usd","usd": $rate * widgets.number_of_images}
|
||||
)
|
||||
""",
|
||||
@ -167,7 +174,12 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
|
||||
options=[
|
||||
"grok-imagine-image-quality",
|
||||
"grok-imagine-image-pro",
|
||||
"grok-imagine-image",
|
||||
"grok-imagine-image-beta",
|
||||
],
|
||||
),
|
||||
IO.Image.Input("image", display_name="images"),
|
||||
IO.String.Input(
|
||||
@ -228,11 +240,19 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images", "resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
|
||||
{"type":"usd","usd": 0.002 + $rate * widgets.number_of_images}
|
||||
$isQualityModel := widgets.model = "grok-imagine-image-quality";
|
||||
$isPro := $contains(widgets.model, "pro");
|
||||
$rate := $isQualityModel
|
||||
? (widgets.resolution = "1k" ? 0.05 : 0.07)
|
||||
: ($isPro ? 0.07 : 0.02);
|
||||
$base := $isQualityModel ? 0.01 : 0.002;
|
||||
$output := $rate * widgets.number_of_images;
|
||||
$isPro
|
||||
? {"type":"usd","usd": $base + $output}
|
||||
: {"type":"range_usd","min_usd": $base + $output, "max_usd": 3 * $base + $output}
|
||||
)
|
||||
""",
|
||||
),
|
||||
|
||||
@ -2787,11 +2787,15 @@ class MotionControl(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {"std": 0.07, "pro": 0.112};
|
||||
{"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}}
|
||||
$prices := {
|
||||
"kling-v3": {"std": 0.126, "pro": 0.168},
|
||||
"kling-v2-6": {"std": 0.07, "pro": 0.112}
|
||||
};
|
||||
$modelPrices := $lookup($prices, widgets.model);
|
||||
{"type":"usd","usd": $lookup($modelPrices, widgets.mode), "format":{"suffix":"/second"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
|
||||
@ -488,10 +488,30 @@ async def _diagnose_connectivity() -> dict[str, bool]:
|
||||
"api_accessible": False,
|
||||
}
|
||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||
|
||||
# Probe Google and Baidu in parallel: Google is blocked by the GFW in mainland China, so a Baidu probe is required
|
||||
# to correctly detect that Chinese users with working internet do have working internet.
|
||||
internet_probe_urls = ("https://www.google.com", "https://www.baidu.com")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
with contextlib.suppress(ClientError, OSError):
|
||||
async with session.get("https://www.google.com") as resp:
|
||||
results["internet_accessible"] = resp.status < 500
|
||||
async def _probe(url: str) -> bool:
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
return resp.status < 500
|
||||
except (ClientError, OSError, asyncio.TimeoutError):
|
||||
return False
|
||||
|
||||
probe_tasks = [asyncio.create_task(_probe(u)) for u in internet_probe_urls]
|
||||
try:
|
||||
for fut in asyncio.as_completed(probe_tasks):
|
||||
if await fut:
|
||||
results["internet_accessible"] = True
|
||||
break
|
||||
finally:
|
||||
for t in probe_tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
await asyncio.gather(*probe_tasks, return_exceptions=True)
|
||||
if not results["internet_accessible"]:
|
||||
return results
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode):
|
||||
IO.Int.Input("bpm", default=120, min=10, max=300),
|
||||
IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
|
||||
IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||
IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||
IO.Combo.Input("language", options=['ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id', 'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no', 'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh', 'unknown'], default='en'),
|
||||
IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||
IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||
IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||
|
||||
@ -92,7 +92,7 @@ class SamplerEulerCFGpp(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SamplerEulerCFGpp",
|
||||
display_name="SamplerEulerCFG++",
|
||||
category="_for_testing", # "sampling/custom_sampling/samplers"
|
||||
category="experimental", # "sampling/custom_sampling/samplers"
|
||||
inputs=[
|
||||
io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
|
||||
],
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
||||
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
|
||||
- ARVideoI2V: image-to-video conditioning for AR models (seeds KV cache with start image)
|
||||
"""
|
||||
|
||||
import torch
|
||||
@ -9,6 +10,7 @@ from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
@ -71,12 +73,62 @@ class SamplerARVideo(io.ComfyNode):
|
||||
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
|
||||
|
||||
|
||||
class ARVideoI2V(io.ComfyNode):
|
||||
"""Image-to-video setup for AR video models (Causal Forcing, Self-Forcing).
|
||||
|
||||
VAE-encodes the start image and stores it in the model's transformer_options
|
||||
so that sample_ar_video can seed the KV cache before denoising.
|
||||
Uses the same T2V model checkpoint -- no separate I2V architecture needed.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ARVideoI2V",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("start_image"),
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=1024, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="MODEL"),
|
||||
io.Latent.Output(display_name="LATENT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, vae, start_image, width, height, length, batch_size) -> io.NodeOutput:
|
||||
start_image = comfy.utils.common_upscale(
|
||||
start_image[:1].movedim(-1, 1), width, height, "bilinear", "center"
|
||||
).movedim(1, -1)
|
||||
|
||||
initial_latent = vae.encode(start_image[:, :, :, :3])
|
||||
|
||||
m = model.clone()
|
||||
to = m.model_options.setdefault("transformer_options", {})
|
||||
ar_cfg = to.setdefault("ar_config", {})
|
||||
ar_cfg["initial_latent"] = initial_latent
|
||||
|
||||
lat_t = ((length - 1) // 4) + 1
|
||||
latent = torch.zeros(
|
||||
[batch_size, 16, lat_t, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
return io.NodeOutput(m, {"samples": latent})
|
||||
|
||||
|
||||
class ARVideoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EmptyARVideoLatent,
|
||||
SamplerARVideo,
|
||||
ARVideoI2V,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ class UNetSelfAttentionMultiply(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetSelfAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
category="experimental/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||
@ -48,7 +48,7 @@ class UNetCrossAttentionMultiply(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetCrossAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
category="experimental/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||
@ -72,7 +72,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="CLIPAttentionMultiply",
|
||||
search_aliases=["clip attention scale", "text encoder attention"],
|
||||
category="_for_testing/attention_experiments",
|
||||
category="experimental/attention_experiments",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||
@ -106,7 +106,7 @@ class UNetTemporalAttentionMultiply(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="UNetTemporalAttentionMultiply",
|
||||
category="_for_testing/attention_experiments",
|
||||
category="experimental/attention_experiments",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||
|
||||
@ -10,6 +10,7 @@ class AudioEncoderLoader(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="AudioEncoderLoader",
|
||||
display_name="Load Audio Encoder",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
|
||||
60
comfy_extras/nodes_bg_removal.py
Normal file
60
comfy_extras/nodes_bg_removal.py
Normal file
@ -0,0 +1,60 @@
|
||||
import folder_paths
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy.bg_removal_model import load
|
||||
|
||||
|
||||
class LoadBackgroundRemovalModel(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
files = folder_paths.get_filename_list("background_removal")
|
||||
return IO.Schema(
|
||||
node_id="LoadBackgroundRemovalModel",
|
||||
display_name="Load Background Removal Model",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
|
||||
],
|
||||
outputs=[
|
||||
IO.BackgroundRemoval.Output("bg_model")
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, bg_removal_name):
|
||||
path = folder_paths.get_full_path_or_raise("background_removal", bg_removal_name)
|
||||
bg = load(path)
|
||||
if bg is None:
|
||||
raise RuntimeError("ERROR: background model file is invalid and does not contain a valid background removal model.")
|
||||
return IO.NodeOutput(bg)
|
||||
|
||||
class RemoveBackground(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RemoveBackground",
|
||||
display_name="Remove Background",
|
||||
category="image/background removal",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Input image to remove the background from"),
|
||||
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
|
||||
],
|
||||
outputs=[
|
||||
IO.Mask.Output("mask", tooltip="Generated foreground mask")
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, image, bg_removal_model):
|
||||
mask = bg_removal_model.encode_image(image)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
class BackgroundRemovalExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
LoadBackgroundRemovalModel,
|
||||
RemoveBackground
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BackgroundRemovalExtension:
|
||||
return BackgroundRemovalExtension()
|
||||
@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanCameraEmbedding",
|
||||
category="camera",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"camera_pose",
|
||||
|
||||
@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||
batch_size = max(len(image), len(alpha))
|
||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
||||
alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:])
|
||||
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
|
||||
image = comfy.utils.repeat_to_batch_size(image, batch_size)
|
||||
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
|
||||
|
||||
@ -8,7 +8,7 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeControlnet",
|
||||
category="_for_testing/conditioning",
|
||||
category="experimental/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Conditioning.Input("conditioning"),
|
||||
@ -35,7 +35,7 @@ class T5TokenizerOptions(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="T5TokenizerOptions",
|
||||
category="_for_testing/conditioning",
|
||||
category="experimental/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
|
||||
|
||||
@ -10,7 +10,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ContextWindowsManual",
|
||||
display_name="Context Windows (Manual)",
|
||||
category="context",
|
||||
category="model_patches",
|
||||
description="Manually set context windows.",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||
@ -29,6 +29,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
|
||||
model = model.clone()
|
||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||
@ -50,7 +51,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
dim=dim,
|
||||
freenoise=freenoise,
|
||||
cond_retain_index_list=cond_retain_index_list,
|
||||
split_conds_to_windows=split_conds_to_windows
|
||||
split_conds_to_windows=split_conds_to_windows,
|
||||
causal_window_fix=causal_window_fix,
|
||||
)
|
||||
# make memory usage calculation only take into account the context window latents
|
||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||
|
||||
@ -984,7 +984,7 @@ class AddNoise(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="AddNoise",
|
||||
category="_for_testing/custom_sampling/noise",
|
||||
category="experimental/custom_sampling/noise",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@ -1034,7 +1034,7 @@ class ManualSigmas(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ManualSigmas",
|
||||
search_aliases=["custom noise schedule", "define sigmas"],
|
||||
category="_for_testing/custom_sampling",
|
||||
category="experimental/custom_sampling",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
||||
|
||||
@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode):
|
||||
node_id="DifferentialDiffusion",
|
||||
search_aliases=["inpaint gradient", "variable denoise strength"],
|
||||
display_name="Differential Diffusion",
|
||||
category="_for_testing",
|
||||
category="experimental",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input(
|
||||
|
||||
@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode):
|
||||
append = execute # TODO: remove
|
||||
|
||||
|
||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||
PREFERRED_KONTEXT_RESOLUTIONS = [
|
||||
(672, 1568),
|
||||
(688, 1504),
|
||||
(720, 1456),
|
||||
@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode):
|
||||
width = image.shape[2]
|
||||
height = image.shape[1]
|
||||
aspect_ratio = width / height
|
||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
|
||||
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
||||
return io.NodeOutput(image)
|
||||
|
||||
|
||||
@ -60,7 +60,7 @@ class FreSca(io.ComfyNode):
|
||||
node_id="FreSca",
|
||||
search_aliases=["frequency guidance"],
|
||||
display_name="FreSca",
|
||||
category="_for_testing",
|
||||
category="experimental",
|
||||
description="Applies frequency-dependent scaling to the guidance",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
|
||||
@ -131,6 +131,8 @@ class HunyuanVideo15SuperResolution(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HunyuanVideo15SuperResolution",
|
||||
display_name="Hunyuan Video 1.5 Super Resolution",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -381,6 +383,8 @@ class HunyuanRefinerLatent(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HunyuanRefinerLatent",
|
||||
display_name="Hunyuan Latent Refiner",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
|
||||
@ -40,7 +40,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Hunyuan3Dv2Conditioning",
|
||||
category="conditioning/video_models",
|
||||
category="conditioning/3d_models",
|
||||
inputs=[
|
||||
IO.ClipVisionOutput.Input("clip_vision_output"),
|
||||
],
|
||||
@ -65,7 +65,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Hunyuan3Dv2ConditioningMultiView",
|
||||
category="conditioning/video_models",
|
||||
category="conditioning/3d_models",
|
||||
inputs=[
|
||||
IO.ClipVisionOutput.Input("front", optional=True),
|
||||
IO.ClipVisionOutput.Input("left", optional=True),
|
||||
@ -424,6 +424,7 @@ class VoxelToMeshBasic(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMeshBasic",
|
||||
display_name="Voxel to Mesh (Basic)",
|
||||
category="3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
@ -453,6 +454,7 @@ class VoxelToMesh(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMesh",
|
||||
display_name="Voxel to Mesh",
|
||||
category="3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
|
||||
@ -102,6 +102,7 @@ class HypernetworkLoader(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="HypernetworkLoader",
|
||||
display_name="Load Hypernetwork",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
|
||||
@ -91,7 +91,7 @@ class LoraSave(io.ComfyNode):
|
||||
node_id="LoraSave",
|
||||
search_aliases=["export lora"],
|
||||
display_name="Extract and Save Lora",
|
||||
category="_for_testing",
|
||||
category="experimental",
|
||||
inputs=[
|
||||
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
||||
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),
|
||||
|
||||
@ -106,12 +106,12 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
if bypass:
|
||||
return (latent,)
|
||||
|
||||
samples = latent["samples"]
|
||||
samples = latent["samples"].clone()
|
||||
_, height_scale_factor, width_scale_factor = (
|
||||
vae.downscale_index_formula
|
||||
)
|
||||
|
||||
batch, _, latent_frames, latent_height, latent_width = samples.shape
|
||||
_, _, _, latent_height, latent_width = samples.shape
|
||||
width = latent_width * width_scale_factor
|
||||
height = latent_height * height_scale_factor
|
||||
|
||||
@ -124,11 +124,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
|
||||
samples[:, :, :t.shape[2]] = t
|
||||
|
||||
conditioning_latent_frames_mask = torch.ones(
|
||||
(batch, 1, latent_frames, 1, 1),
|
||||
dtype=torch.float32,
|
||||
device=samples.device,
|
||||
)
|
||||
conditioning_latent_frames_mask = get_noise_mask(latent)
|
||||
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
||||
|
||||
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
|
||||
@ -236,7 +232,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
|
||||
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
|
||||
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
return encode_pixels, t
|
||||
@ -594,7 +590,8 @@ class LTXVPreprocess(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVPreprocess",
|
||||
category="image",
|
||||
display_name="LTXV Preprocess",
|
||||
category="video/preprocessors",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Int.Input(
|
||||
|
||||
@ -11,7 +11,7 @@ class Mahiro(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="Mahiro",
|
||||
display_name="Positive-Biased Guidance",
|
||||
category="_for_testing",
|
||||
category="experimental",
|
||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
|
||||
@ -40,10 +40,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
||||
|
||||
inverse_mask = torch.ones_like(mask) - mask
|
||||
|
||||
source_portion = mask * source[..., :visible_height, :visible_width]
|
||||
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
|
||||
source_rgb = source[:, :3, :visible_height, :visible_width]
|
||||
dest_slice = destination[..., top:bottom, left:right]
|
||||
|
||||
if destination.shape[1] == 4:
|
||||
if torch.max(dest_slice) == 0:
|
||||
destination[:, :3, top:bottom, left:right] = source_rgb
|
||||
destination[:, 3:4, top:bottom, left:right] = mask
|
||||
else:
|
||||
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
|
||||
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
|
||||
else:
|
||||
source_portion = mask * source_rgb
|
||||
destination_portion = inverse_mask * dest_slice
|
||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||
|
||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||
return destination
|
||||
|
||||
class LatentCompositeMasked(IO.ComfyNode):
|
||||
@ -84,18 +95,23 @@ class ImageCompositeMasked(IO.ComfyNode):
|
||||
display_name="Image Composite Masked",
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("destination"),
|
||||
IO.Image.Input("source"),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Boolean.Input("resize_source", default=False),
|
||||
IO.Image.Input("destination", optional=True),
|
||||
IO.Mask.Input("mask", optional=True),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
||||
def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
|
||||
if destination is None: # transparent rgba
|
||||
B, H, W, C = source.shape
|
||||
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
|
||||
if C == 3:
|
||||
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
|
||||
destination, source = node_helpers.image_alpha_fix(destination, source)
|
||||
destination = destination.clone().movedim(-1, 1)
|
||||
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
||||
@ -381,7 +397,6 @@ class GrowMask(IO.ComfyNode):
|
||||
|
||||
expand_mask = execute # TODO: remove
|
||||
|
||||
|
||||
class ThresholdMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user