mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
Merge branch 'master' into feat/comfykit-awq-w4a16-modulation
This commit is contained in:
commit
2322ff5bf7
@ -1,2 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-dynamic-vram
|
||||
pause
|
||||
31
.github/workflows/openapi-lint.yml
vendored
Normal file
31
.github/workflows/openapi-lint.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: OpenAPI Lint
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'openapi.yaml'
|
||||
- '.spectral.yaml'
|
||||
- '.github/workflows/openapi-lint.yml'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
spectral:
|
||||
name: Run Spectral
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install Spectral
|
||||
run: npm install -g @stoplight/spectral-cli@6
|
||||
|
||||
- name: Lint openapi.yaml
|
||||
run: spectral lint openapi.yaml --ruleset .spectral.yaml --fail-severity=error
|
||||
2
.github/workflows/stable-release.yml
vendored
2
.github/workflows/stable-release.yml
vendored
@ -145,6 +145,8 @@ jobs:
|
||||
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
|
||||
cp ../update_comfyui_and_python_dependencies.bat ./update/
|
||||
|
||||
echo 'local-portable' > ComfyUI/.comfy_environment
|
||||
|
||||
cd ..
|
||||
|
||||
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
|
||||
|
||||
45
.github/workflows/tag-dispatch-cloud.yml
vendored
Normal file
45
.github/workflows/tag-dispatch-cloud.yml
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
name: Tag Dispatch to Cloud
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
jobs:
|
||||
dispatch-cloud:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Send repository dispatch to cloud
|
||||
env:
|
||||
DISPATCH_TOKEN: ${{ secrets.CLOUD_REPO_DISPATCH_TOKEN }}
|
||||
RELEASE_TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ -z "${DISPATCH_TOKEN:-}" ]; then
|
||||
echo "::error::CLOUD_REPO_DISPATCH_TOKEN is required but not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
RELEASE_URL="https://github.com/${{ github.repository }}/releases/tag/${RELEASE_TAG}"
|
||||
|
||||
PAYLOAD="$(jq -n \
|
||||
--arg release_tag "$RELEASE_TAG" \
|
||||
--arg release_url "$RELEASE_URL" \
|
||||
'{
|
||||
event_type: "comfyui_tag_pushed",
|
||||
client_payload: {
|
||||
release_tag: $release_tag,
|
||||
release_url: $release_url
|
||||
}
|
||||
}')"
|
||||
|
||||
curl -fsSL \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
|
||||
https://api.github.com/repos/Comfy-Org/cloud/dispatches \
|
||||
-d "$PAYLOAD"
|
||||
|
||||
echo "✅ Dispatched ComfyUI tag ${RELEASE_TAG} to Comfy-Org/cloud"
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -23,3 +23,4 @@ web_custom_versions/
|
||||
.DS_Store
|
||||
filtered-openapi.yaml
|
||||
uv.lock
|
||||
.comfy_environment
|
||||
|
||||
91
.spectral.yaml
Normal file
91
.spectral.yaml
Normal file
@ -0,0 +1,91 @@
|
||||
extends:
|
||||
- spectral:oas
|
||||
|
||||
# Severity levels: error, warn, info, hint, off
|
||||
# Rules from the built-in "spectral:oas" ruleset are active by default.
|
||||
# Below we tune severity and add custom rules for our conventions.
|
||||
#
|
||||
# This ruleset mirrors Comfy-Org/cloud/.spectral.yaml so specs across the
|
||||
# organization are linted against a single consistent standard.
|
||||
|
||||
rules:
|
||||
# -----------------------------------------------------------------------
|
||||
# Built-in rule severity overrides
|
||||
# -----------------------------------------------------------------------
|
||||
operation-operationId: error
|
||||
operation-description: warn
|
||||
operation-tag-defined: error
|
||||
info-contact: off
|
||||
info-description: warn
|
||||
no-eval-in-markdown: error
|
||||
no-$ref-siblings: error
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: naming conventions
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Property names should be snake_case
|
||||
property-name-snake-case:
|
||||
description: Property names must be snake_case
|
||||
severity: warn
|
||||
given: "$.components.schemas.*.properties[*]~"
|
||||
then:
|
||||
function: pattern
|
||||
functionOptions:
|
||||
match: "^[a-z][a-z0-9]*(_[a-z0-9]+)*$"
|
||||
|
||||
# Operation IDs should be camelCase
|
||||
operation-id-camel-case:
|
||||
description: Operation IDs must be camelCase
|
||||
severity: warn
|
||||
given: "$.paths.*.*.operationId"
|
||||
then:
|
||||
function: pattern
|
||||
functionOptions:
|
||||
match: "^[a-z][a-zA-Z0-9]*$"
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: response conventions
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Error responses (4xx, 5xx) should use a consistent shape
|
||||
error-response-schema:
|
||||
description: Error responses should reference a standard error schema
|
||||
severity: hint
|
||||
given: "$.paths.*.*.responses[?(@property >= '400' && @property < '600')].content['application/json'].schema"
|
||||
then:
|
||||
field: "$ref"
|
||||
function: truthy
|
||||
|
||||
# All 2xx responses with JSON body should have a schema
|
||||
response-schema-defined:
|
||||
description: Success responses with JSON content should define a schema
|
||||
severity: warn
|
||||
given: "$.paths.*.*.responses[?(@property >= '200' && @property < '300')].content['application/json']"
|
||||
then:
|
||||
field: schema
|
||||
function: truthy
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: best practices
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Path parameters must have a description
|
||||
path-param-description:
|
||||
description: Path parameters should have a description
|
||||
severity: warn
|
||||
given:
|
||||
- "$.paths.*.parameters[?(@.in == 'path')]"
|
||||
- "$.paths.*.*.parameters[?(@.in == 'path')]"
|
||||
then:
|
||||
field: description
|
||||
function: truthy
|
||||
|
||||
# Schemas should have a description
|
||||
schema-description:
|
||||
description: Component schemas should have a description
|
||||
severity: hint
|
||||
given: "$.components.schemas.*"
|
||||
then:
|
||||
field: description
|
||||
function: truthy
|
||||
@ -1,2 +1,2 @@
|
||||
# Admins
|
||||
* @comfyanonymous @kosinkadink @guill
|
||||
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai
|
||||
|
||||
23
README.md
23
README.md
@ -1,7 +1,7 @@
|
||||
<div align="center">
|
||||
|
||||
# ComfyUI
|
||||
**The most powerful and modular visual AI engine and application.**
|
||||
**The most powerful and modular AI engine for content creation.**
|
||||
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
@ -31,10 +31,16 @@
|
||||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
|
||||

|
||||
<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
|
||||
<br>
|
||||
</div>
|
||||
|
||||
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
|
||||
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
|
||||
- ComfyUI natively supports the latest open-source state of the art models.
|
||||
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
|
||||
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
|
||||
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
|
||||
- It integrates seamlessly into production pipelines with our API endpoints.
|
||||
|
||||
## Get Started
|
||||
|
||||
@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
||||
- Ernie Image
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
@ -126,7 +133,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||
- Releases a new major stable version (e.g., v0.7.0) roughly every 2 weeks.
|
||||
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
||||
- Minor versions will be used for releases off the master branch.
|
||||
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
||||
@ -193,13 +200,15 @@ If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
#### All Official Portable Downloads:
|
||||
|
||||
[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
|
||||
[Portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
[Portable for Nvidia GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) (supports 20 series and above).
|
||||
|
||||
[Portable for Nvidia GPUs with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ def frontend_install_warning_message():
|
||||
return f"""
|
||||
{get_missing_requirements_message()}
|
||||
|
||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||
The ComfyUI frontend is shipped in a pip package so it needs to be updated separately from the ComfyUI code.
|
||||
""".strip()
|
||||
|
||||
def parse_version(version: str) -> tuple[int, int, int]:
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
@ -31,8 +33,22 @@ class NodeReplaceManager:
|
||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register(self, node_replace: NodeReplace):
|
||||
"""Register a node replacement mapping."""
|
||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
"""Register a node replacement mapping.
|
||||
|
||||
Idempotent: if a replacement with the same (old_node_id, new_node_id)
|
||||
is already registered, the duplicate is ignored. This prevents stale
|
||||
entries from accumulating when custom nodes are reloaded in the same
|
||||
process (e.g. via ComfyUI-Manager).
|
||||
"""
|
||||
existing = self._replacements.setdefault(node_replace.old_node_id, [])
|
||||
for entry in existing:
|
||||
if entry.new_node_id == node_replace.new_node_id:
|
||||
logging.debug(
|
||||
"Node replacement %s -> %s already registered, ignoring duplicate.",
|
||||
node_replace.old_node_id, node_replace.new_node_id,
|
||||
)
|
||||
return
|
||||
existing.append(node_replace)
|
||||
|
||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||
"""Get replacements for an old node ID."""
|
||||
|
||||
@ -28,8 +28,8 @@ def get_file_info(path: str, relative_to: str) -> FileInfo:
|
||||
return {
|
||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
||||
"size": os.path.getsize(path),
|
||||
"modified": os.path.getmtime(path),
|
||||
"created": os.path.getctime(path)
|
||||
"modified": int(os.path.getmtime(path) * 1000),
|
||||
"created": int(os.path.getctime(path) * 1000),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -431,9 +431,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts image brightness and contrast using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -162,7 +162,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Canny to Image (Z-Image-Turbo)",
|
||||
"name": "Canny to Image (Z-Image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1553,7 +1553,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Canny to image"
|
||||
"category": "Image generation and editing/Canny to image",
|
||||
"description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1574,4 +1575,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -192,7 +192,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Canny to Video (LTX 2.0)",
|
||||
"name": "Canny to Video (LTX 2.0)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -3600,7 +3600,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Canny to video"
|
||||
"category": "Video generation and editing/Canny to video",
|
||||
"description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -3616,4 +3617,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -377,8 +377,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adds lens-style chromatic aberration (color fringing) using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -596,7 +596,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts saturation, temperature, tint, and vibrance using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -1129,7 +1129,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Balances colors across shadows, midtones, and highlights using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -608,7 +608,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Fine-tunes tone and color with per-channel curve adjustments using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -1609,7 +1609,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image Tools/Crop"
|
||||
"category": "Image Tools/Crop",
|
||||
"description": "Splits an image into a 2×2 grid of four equal tiles."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -2946,7 +2946,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image Tools/Crop"
|
||||
"category": "Image Tools/Crop",
|
||||
"description": "Splits an image into a 3×3 grid of nine equal tiles."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1579,7 +1579,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Depth to image"
|
||||
"category": "Image generation and editing/Depth to image",
|
||||
"description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning."
|
||||
},
|
||||
{
|
||||
"id": "458bdf3c-4b58-421c-af50-c9c663a4d74c",
|
||||
@ -2461,7 +2462,8 @@
|
||||
]
|
||||
},
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -4233,7 +4233,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Depth to video"
|
||||
"category": "Video generation and editing/Depth to video",
|
||||
"description": "Generates video from depth maps using LTX-2, with optional synchronized audio."
|
||||
},
|
||||
{
|
||||
"id": "38b60539-50a7-42f9-a5fe-bdeca26272e2",
|
||||
@ -5192,7 +5193,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -450,9 +450,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Blur"
|
||||
"category": "Image Tools/Blur",
|
||||
"description": "Applies bilateral (edge-preserving) blur to soften images while retaining detail."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -580,8 +580,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adds procedural film grain texture for a cinematic look via GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3350,7 +3350,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Video generation and editing/First-Last-Frame to Video"
|
||||
"category": "Video generation and editing/First-Last-Frame to Video",
|
||||
"description": "Generates a video interpolating between first and last keyframes using LTX-2.3."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -575,8 +575,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adds a glow/bloom effect around bright image areas via GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -752,8 +752,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts hue, saturation, and lightness of an image using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -374,7 +374,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Blur"
|
||||
"category": "Image Tools/Blur",
|
||||
"description": "Applies Gaussian, Box, or Radial blur to soften images and create stylized depth or motion effects."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -310,7 +310,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Image Captioning"
|
||||
"category": "Text generation/Image Captioning",
|
||||
"description": "Generates descriptive captions for images using Google's Gemini multimodal LLM."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -315,8 +315,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Manipulates individual RGBA channels for masking, compositing, and channel effects."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2138,7 +2138,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits images via text instructions using FireRed Image Edit 1.1, a diffusion-based instruction-following editing model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1472,7 +1472,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits an input image via text instructions using FLUX.2 [klein] 4B."
|
||||
},
|
||||
{
|
||||
"id": "6007e698-2ebd-4917-84d8-299b35d7b7ab",
|
||||
@ -1821,7 +1822,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Applies reference image conditioning for style/identity transfer (Flux.2 Klein 4B)."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1837,4 +1839,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
|
||||
@ -1417,7 +1417,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits images via text instructions using LongCat Image Edit, an instruction-following image editing diffusion model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -132,7 +132,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image Edit (Qwen 2511)",
|
||||
"name": "Image Edit (Qwen 2511)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1468,7 +1468,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits images via text instructions using Qwen-Image-Edit-2511 with improved character consistency and integrated LoRA."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1489,4 +1490,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -1188,7 +1188,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Inpaint image"
|
||||
"category": "Image generation and editing/Inpaint image",
|
||||
"description": "Inpaints masked image regions using Flux.1 fill [dev], Black Forest Labs' inpainting/outpainting model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1202,4 +1203,4 @@
|
||||
},
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1548,7 +1548,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Inpaint image"
|
||||
"category": "Image generation and editing/Inpaint image",
|
||||
"description": "Inpaints masked regions using Qwen-Image, extending its multilingual text rendering to inpainting tasks."
|
||||
},
|
||||
{
|
||||
"id": "56a1f603-fbd2-40ed-94ef-c9ecbd96aca8",
|
||||
@ -1907,7 +1908,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Expands and softens mask edges to reduce visible seams after image processing."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -742,9 +742,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts black point, white point, and gamma for tonal range control via GPU shader."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -1919,7 +1919,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Outpaint image"
|
||||
"category": "Image generation and editing/Outpaint image",
|
||||
"description": "Outpaints beyond image boundaries using Qwen-Image's outpainting capabilities."
|
||||
},
|
||||
{
|
||||
"id": "f93c215e-c393-460e-9534-ed2c3d8a652e",
|
||||
@ -2278,7 +2279,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Expands and softens mask edges to reduce visible seams after image processing."
|
||||
},
|
||||
{
|
||||
"id": "2a4b2cc0-db37-4302-a067-da392f38f06b",
|
||||
@ -2733,7 +2735,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Scales both image and mask together while preserving alignment for editing workflows."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -141,7 +141,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image Upscale(Z-image-Turbo)",
|
||||
"name": "Image Upscale (Z-image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1302,7 +1302,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Enhance"
|
||||
"category": "Image generation and editing/Enhance",
|
||||
"description": "Upscales images to higher resolution using Z-Image-Turbo."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -99,7 +99,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image to Depth Map (Lotus)",
|
||||
"name": "Image to Depth Map (Lotus)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -948,7 +948,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Depth to image"
|
||||
"category": "Image generation and editing/Depth to image",
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -964,4 +965,4 @@
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -1586,7 +1586,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Image to layers"
|
||||
"category": "Image generation and editing/Image to layers",
|
||||
"description": "Decomposes an image into variable-resolution RGBA layers for independent editing using Qwen-Image-Layered."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -72,7 +72,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image to Model (Hunyuan3d 2.1)",
|
||||
"name": "Image to 3D Model (Hunyuan3d 2.1)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -765,7 +765,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "3D/Image to 3D Model"
|
||||
"category": "3D/Image to 3D Model",
|
||||
"description": "Generates 3D mesh models from a single input image using Hunyuan3D 2.0/2.1."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -4223,7 +4223,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "Vue-corrected"
|
||||
},
|
||||
"category": "Video generation and editing/Image to video"
|
||||
"category": "Video generation and editing/Image to video",
|
||||
"description": "Generates video from a single input image using LTX-2.3."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -206,7 +206,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image to Video (Wan 2.2)",
|
||||
"name": "Image to Video (Wan 2.2)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -2027,7 +2027,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Image to video"
|
||||
"category": "Video generation and editing/Image to video",
|
||||
"description": "Generates video from an image and text prompt using Wan 2.2, supporting T2V and I2V."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -134,7 +134,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Pose to Image (Z-Image-Turbo)",
|
||||
"name": "Pose to Image (Z-Image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1298,7 +1298,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Pose to image"
|
||||
"category": "Image generation and editing/Pose to image",
|
||||
"description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1319,4 +1320,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -3870,7 +3870,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Pose to video"
|
||||
"category": "Video generation and editing/Pose to video",
|
||||
"description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -270,9 +270,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Prompt enhance"
|
||||
"category": "Text generation/Prompt enhance",
|
||||
"description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -302,8 +302,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Sharpen"
|
||||
"category": "Image Tools/Sharpen",
|
||||
"description": "Sharpens image details using a GPU fragment shader for enhanced clarity."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -222,7 +222,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Text to Audio (ACE-Step 1.5)",
|
||||
"name": "Text to Audio (ACE-Step 1.5)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1502,7 +1502,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Audio/Music generation"
|
||||
"category": "Audio/Music generation",
|
||||
"description": "Generates audio/music from text prompts using ACE-Step 1.5, a diffusion-based audio generation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1518,4 +1519,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -1029,7 +1029,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Flux.1 [dev], Black Forest Labs' 12B diffusion model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1043,4 +1044,4 @@
|
||||
},
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1023,7 +1023,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Flux.1 Krea Dev, a Black Forest Labs × Krea collaboration variant."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1037,4 +1038,4 @@
|
||||
},
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1104,7 +1104,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using NetaYume Lumina, fine-tuned from Neta Lumina for anime-style and illustration generation."
|
||||
},
|
||||
{
|
||||
"id": "a07fdf06-1bda-4dac-bdbd-63ee8ebca1c9",
|
||||
@ -1458,11 +1459,12 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Encodes a negative text prompt via CLIP for classifier-free guidance in anime-style generation (NetaYume Lumina)."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1941,7 +1941,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "Vue-corrected"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Qwen-Image-2512, with enhanced human realism and finer natural detail over the base version."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1873,7 +1873,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Qwen-Image, Alibaba's 20B MMDiT model with excellent multilingual text rendering."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -149,7 +149,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Text to Image (Z-Image-Turbo)",
|
||||
"name": "Text to Image (Z-Image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1054,7 +1054,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Z-Image-Turbo, Alibaba's distilled 6B DiT model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1075,4 +1076,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -4286,7 +4286,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "Vue-corrected"
|
||||
},
|
||||
"category": "Video generation and editing/Text to video"
|
||||
"category": "Video generation and editing/Text to video",
|
||||
"description": "Generates video from text prompts using LTX-2.3, Lightricks' video diffusion model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1572,7 +1572,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Text to video"
|
||||
"category": "Video generation and editing/Text to video",
|
||||
"description": "Generates video from text prompts using Wan2.2, Alibaba's diffusion video model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1586,4 +1587,4 @@
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -434,8 +434,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Sharpen"
|
||||
"category": "Image Tools/Sharpen",
|
||||
"description": "Enhances edge contrast via unsharp masking for a sharper image appearance."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -307,7 +307,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Video Captioning"
|
||||
"category": "Text generation/Video Captioning",
|
||||
"description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -165,7 +165,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Video Inpaint(Wan2.1 VACE)",
|
||||
"name": "Video Inpaint (Wan 2.1 VACE)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -2368,7 +2368,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Inpaint video"
|
||||
"category": "Video generation and editing/Inpaint video",
|
||||
"description": "Inpaints masked regions in video frames using Wan 2.1 VACE."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -584,8 +584,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video Tools/Stitch videos"
|
||||
"category": "Video Tools/Stitch videos",
|
||||
"description": "Stitches multiple video clips into a single sequential video file."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -412,9 +412,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Enhance video"
|
||||
"category": "Video generation and editing/Enhance video",
|
||||
"description": "Upscales video to 4× resolution using a GAN-based upscaling model."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
7
comfy/background_removal/birefnet.json
Normal file
7
comfy/background_removal/birefnet.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"model_type": "birefnet",
|
||||
"image_std": [1.0, 1.0, 1.0],
|
||||
"image_mean": [0.0, 0.0, 0.0],
|
||||
"image_size": 1024,
|
||||
"resize_to_original": true
|
||||
}
|
||||
689
comfy/background_removal/birefnet.py
Normal file
689
comfy/background_removal/birefnet.py
Normal file
@ -0,0 +1,689 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
import torch.nn.functional as F
|
||||
from torchvision.ops import deform_conv2d
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
k, v = kv[0], kv[1]
|
||||
|
||||
x = optimized_attention(
|
||||
q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
|
||||
).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
|
||||
self.act = nn.GELU()
|
||||
self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
|
||||
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
||||
mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||
norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
self.norm1 = norm_layer(dim, device=device, dtype=dtype)
|
||||
self.attn = WindowAttention(
|
||||
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.norm2 = norm_layer(dim, device=device, dtype=dtype)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.H = None
|
||||
self.W = None
|
||||
|
||||
def forward(self, x, mask_matrix):
|
||||
B, L, C = x.shape
|
||||
H, W = self.H, self.W
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
pad_l = pad_t = 0
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
_, Hp, Wp, _ = x.shape
|
||||
|
||||
if self.shift_size > 0:
|
||||
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
attn_mask = mask_matrix
|
||||
else:
|
||||
shifted_x = x
|
||||
attn_mask = None
|
||||
|
||||
x_windows = window_partition(shifted_x, self.window_size)
|
||||
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
|
||||
|
||||
attn_windows = self.attn(x_windows, mask=attn_mask)
|
||||
|
||||
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
||||
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
def __init__(self, dim, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, L, C = x.shape
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# padding
|
||||
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
||||
if pad_input:
|
||||
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
norm_layer=nn.LayerNorm,
|
||||
downsample=None,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.shift_size = window_size // 2
|
||||
self.depth = depth
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
norm_layer=norm_layer,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, H, W):
|
||||
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
||||
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
||||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(img_mask, self.window_size)
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
|
||||
for blk in self.blocks:
|
||||
blk.H, blk.W = H, W
|
||||
x = blk(x, attn_mask)
|
||||
if self.downsample is not None:
|
||||
x_down = self.downsample(x, H, W)
|
||||
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
||||
return x, H, W, x_down, Wh, Ww
|
||||
else:
|
||||
return x, H, W, x, H, W
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
patch_size = (patch_size, patch_size)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.size()
|
||||
if W % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
||||
if H % self.patch_size[0] != 0:
|
||||
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||
|
||||
x = self.proj(x) # B C Wh Ww
|
||||
if self.norm is not None:
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
pretrain_img_size=224,
|
||||
patch_size=4,
|
||||
in_channels=3,
|
||||
embed_dim=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
|
||||
self.pretrain_img_size = pretrain_img_size
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.patch_norm = patch_norm
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = BasicLayer(
|
||||
dim=int(embed_dim * 2 ** i_layer),
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.layers.append(layer)
|
||||
|
||||
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
self.num_features = num_features
|
||||
|
||||
for i_layer in out_indices:
|
||||
layer = norm_layer(num_features[i_layer])
|
||||
layer_name = f'norm{i_layer}'
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
|
||||
outs = []
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
x_out = norm_layer(x_out)
|
||||
|
||||
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
class DeformableConv2d(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False, device=None, dtype=None, operations=None):
|
||||
|
||||
super(DeformableConv2d, self).__init__()
|
||||
|
||||
kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
|
||||
self.stride = stride if type(stride) is tuple else (stride, stride)
|
||||
self.padding = padding
|
||||
|
||||
self.offset_conv = operations.Conv2d(in_channels,
|
||||
2 * kernel_size[0] * kernel_size[1],
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.modulator_conv = operations.Conv2d(in_channels,
|
||||
1 * kernel_size[0] * kernel_size[1],
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.regular_conv = operations.Conv2d(in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
offset = self.offset_conv(x)
|
||||
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
||||
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
|
||||
|
||||
x = deform_conv2d(
|
||||
input=x,
|
||||
offset=offset,
|
||||
weight=weight,
|
||||
bias=None,
|
||||
padding=self.padding,
|
||||
mask=modulator,
|
||||
stride=self.stride,
|
||||
)
|
||||
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
|
||||
return x
|
||||
|
||||
class BasicDecBlk(nn.Module):
|
||||
def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
|
||||
super(BasicDecBlk, self).__init__()
|
||||
inter_channels = 64
|
||||
self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
|
||||
self.relu_in = nn.ReLU(inplace=True)
|
||||
self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
|
||||
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
|
||||
self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
|
||||
self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.bn_in(x)
|
||||
x = self.relu_in(x)
|
||||
x = self.dec_att(x)
|
||||
x = self.conv_out(x)
|
||||
x = self.bn_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class BasicLatBlk(nn.Module):
|
||||
def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
|
||||
super(BasicLatBlk, self).__init__()
|
||||
self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class _ASPPModuleDeformable(nn.Module):
|
||||
def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
|
||||
super(_ASPPModuleDeformable, self).__init__()
|
||||
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
|
||||
stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
|
||||
self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.atrous_conv(x)
|
||||
x = self.bn(x)
|
||||
|
||||
return self.relu(x)
|
||||
|
||||
|
||||
class ASPPDeformable(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
|
||||
super(ASPPDeformable, self).__init__()
|
||||
self.down_scale = 1
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
self.in_channelster = 256 // self.down_scale
|
||||
|
||||
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
|
||||
self.aspp_deforms = nn.ModuleList([
|
||||
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
|
||||
for conv_size in parallel_block_sizes
|
||||
])
|
||||
|
||||
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
|
||||
operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
|
||||
operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=True))
|
||||
self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
|
||||
self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.aspp1(x)
|
||||
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
|
||||
x5 = self.global_avg_pool(x)
|
||||
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
|
||||
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
return x
|
||||
|
||||
class BiRefNet(nn.Module):
|
||||
def __init__(self, config=None, dtype=None, device=None, operations=None):
|
||||
super(BiRefNet, self).__init__()
|
||||
self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
channels = [1536, 768, 384, 192]
|
||||
channels = [c * 2 for c in channels]
|
||||
self.cxt = channels[1:][::-1][-3:]
|
||||
self.squeeze_module = nn.Sequential(*[
|
||||
BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(1)
|
||||
])
|
||||
|
||||
self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward_enc(self, x):
|
||||
x1, x2, x3, x4 = self.bb(x)
|
||||
B, C, H, W = x.shape
|
||||
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
|
||||
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x4 = torch.cat(
|
||||
(
|
||||
*[
|
||||
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||
][-len(CXT):],
|
||||
x4
|
||||
),
|
||||
dim=1
|
||||
)
|
||||
return (x1, x2, x3, x4)
|
||||
|
||||
def forward_ori(self, x):
|
||||
(x1, x2, x3, x4) = self.forward_enc(x)
|
||||
x4 = self.squeeze_module(x4)
|
||||
features = [x, x1, x2, x3, x4]
|
||||
scaled_preds = self.decoder(features)
|
||||
return scaled_preds
|
||||
|
||||
def forward(self, pixel_values, intermediate_output=None):
|
||||
scaled_preds = self.forward_ori(pixel_values)
|
||||
return scaled_preds
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, channels, device, dtype, operations):
|
||||
super(Decoder, self).__init__()
|
||||
# factory kwargs
|
||||
fk = {"device":device, "dtype":dtype, "operations":operations}
|
||||
DecoderBlock = partial(BasicDecBlk, **fk)
|
||||
LateralBlock = partial(BasicLatBlk, **fk)
|
||||
DBlock = partial(SimpleConvs, **fk)
|
||||
|
||||
self.split = True
|
||||
N_dec_ipt = 64
|
||||
ic = 64
|
||||
ipt_cha_opt = 1
|
||||
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
|
||||
|
||||
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
|
||||
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
|
||||
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
|
||||
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
|
||||
|
||||
fk = {"device":device, "dtype":dtype}
|
||||
|
||||
self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
|
||||
|
||||
self.lateral_block4 = LateralBlock(channels[1], channels[1])
|
||||
self.lateral_block3 = LateralBlock(channels[2], channels[2])
|
||||
self.lateral_block2 = LateralBlock(channels[3], channels[3])
|
||||
|
||||
self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
|
||||
self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
|
||||
self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
|
||||
|
||||
_N = 16
|
||||
|
||||
self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||
self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||
self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||
|
||||
[setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
|
||||
[setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
|
||||
|
||||
def get_patches_batch(self, x, p):
|
||||
_size_h, _size_w = p.shape[2:]
|
||||
patches_batch = []
|
||||
for idx in range(x.shape[0]):
|
||||
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
|
||||
patches_x = []
|
||||
for column_x in columns_x:
|
||||
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
|
||||
patch_sample = torch.cat(patches_x, dim=1)
|
||||
patches_batch.append(patch_sample)
|
||||
return torch.cat(patches_batch, dim=0)
|
||||
|
||||
def forward(self, features):
|
||||
x, x1, x2, x3, x4 = features
|
||||
|
||||
patches_batch = self.get_patches_batch(x, x4) if self.split else x
|
||||
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p4 = self.decoder_block4(x4)
|
||||
p4_gdt = self.gdt_convs_4(p4)
|
||||
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
|
||||
p4 = p4 * gdt_attn_4
|
||||
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
|
||||
_p3 = _p4 + self.lateral_block4(x3)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
|
||||
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p3 = self.decoder_block3(_p3)
|
||||
|
||||
p3_gdt = self.gdt_convs_3(p3)
|
||||
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
|
||||
p3 = p3 * gdt_attn_3
|
||||
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
|
||||
_p2 = _p3 + self.lateral_block3(x2)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
|
||||
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p2 = self.decoder_block2(_p2)
|
||||
|
||||
p2_gdt = self.gdt_convs_2(p2)
|
||||
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
|
||||
p2 = p2 * gdt_attn_2
|
||||
|
||||
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
|
||||
_p1 = _p2 + self.lateral_block2(x1)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
||||
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
_p1 = self.decoder_block1(_p1)
|
||||
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
||||
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p1_out = self.conv_out1(_p1)
|
||||
return p1_out
|
||||
|
||||
|
||||
class SimpleConvs(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
|
||||
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv_out(self.conv1(x))
|
||||
78
comfy/bg_removal_model.py
Normal file
78
comfy/bg_removal_model.py
Normal file
@ -0,0 +1,78 @@
|
||||
from .utils import load_torch_file
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.clip_model
|
||||
import comfy.background_removal.birefnet
|
||||
|
||||
BG_REMOVAL_MODELS = {
|
||||
"birefnet": comfy.background_removal.birefnet.BiRefNet
|
||||
}
|
||||
|
||||
class BackgroundRemovalModel():
|
||||
def __init__(self, json_config):
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 1024)
|
||||
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
|
||||
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
|
||||
self.model_type = config.get("model_type", "birefnet")
|
||||
self.config = config.copy()
|
||||
model_class = BG_REMOVAL_MODELS.get(self.model_type)
|
||||
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_image(self, image):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
H, W = image.shape[1], image.shape[2]
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
|
||||
out = self.model(pixel_values=pixel_values)
|
||||
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
|
||||
|
||||
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1] != 1:
|
||||
mask = mask.movedim(-1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def load_background_removal_model(sd):
|
||||
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
bg_model = BackgroundRemovalModel(json_config)
|
||||
m, u = bg_model.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing background removal: {}".format(m))
|
||||
u = set(u)
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
if k not in u:
|
||||
sd.pop(k)
|
||||
return bg_model
|
||||
|
||||
def load(ckpt_path):
|
||||
sd = load_torch_file(ckpt_path)
|
||||
return load_background_removal_model(sd)
|
||||
@ -90,8 +90,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
|
||||
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
|
||||
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
|
||||
|
||||
class LatentPreviewMethod(enum.Enum):
|
||||
NoPreviews = "none"
|
||||
@ -238,6 +238,8 @@ database_default_path = os.path.abspath(
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
||||
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
|
||||
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -63,7 +63,11 @@ class IndexListContextWindow(ContextWindowABC):
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
indices = self.index_list
|
||||
anchor_idx = getattr(self, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
indices = [anchor_idx] + list(indices)
|
||||
idx = tuple([slice(None)] * dim + [indices])
|
||||
window = full[idx]
|
||||
if retain_index_list:
|
||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||
@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
||||
|
||||
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||
if temporal_offset > 0:
|
||||
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
||||
anchor_idx = getattr(window, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
# anchor occupies one of the no-cond positions, so skip one fewer from window.index_list
|
||||
skip_count = temporal_offset - 1
|
||||
else:
|
||||
skip_count = temporal_offset
|
||||
|
||||
indices = [i - temporal_offset for i in window.index_list[skip_count:]]
|
||||
indices = [i for i in indices if 0 <= i]
|
||||
else:
|
||||
indices = list(window.index_list)
|
||||
@ -150,7 +161,8 @@ class ContextFuseMethod:
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
||||
causal_window_fix: bool=True):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -162,6 +174,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
self.causal_window_fix = causal_window_fix
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
@ -318,6 +331,14 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
|
||||
anchor_applied = False
|
||||
if self.causal_window_fix:
|
||||
anchor_idx = window.index_list[0] - 1
|
||||
if 0 <= anchor_idx < x_in.size(self.dim):
|
||||
window.causal_anchor_index = anchor_idx
|
||||
anchor_applied = True
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
@ -332,6 +353,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
|
||||
# strip causal_window_fix anchor if applied
|
||||
if anchor_applied:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
|
||||
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
34
comfy/deploy_environment.py
Normal file
34
comfy/deploy_environment.py
Normal file
@ -0,0 +1,34 @@
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_DEPLOY_ENV = "local-git"
|
||||
_ENV_FILENAME = ".comfy_environment"
|
||||
|
||||
# Resolve the ComfyUI install directory (the parent of this `comfy/` package).
|
||||
# We deliberately avoid `folder_paths.base_path` here because that is overridden
|
||||
# by the `--base-directory` CLI arg to a user-supplied path, whereas the
|
||||
# `.comfy_environment` marker is written by launchers/installers next to the
|
||||
# ComfyUI install itself.
|
||||
_COMFY_INSTALL_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_deploy_environment() -> str:
|
||||
env_file = os.path.join(_COMFY_INSTALL_DIR, _ENV_FILENAME)
|
||||
try:
|
||||
with open(env_file, encoding="utf-8") as f:
|
||||
# Cap the read so a malformed or maliciously crafted file (e.g.
|
||||
# a single huge line with no newline) can't blow up memory.
|
||||
first_line = f.readline(128).strip()
|
||||
value = "".join(c for c in first_line if 32 <= ord(c) < 127)
|
||||
if value:
|
||||
return value
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("Failed to read %s: %s", env_file, e)
|
||||
|
||||
return _DEFAULT_DEPLOY_ENV
|
||||
@ -93,7 +93,7 @@ class Hook:
|
||||
self.hook_scope = hook_scope
|
||||
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||
self.custom_should_register = default_should_register
|
||||
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||
'''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||
|
||||
@property
|
||||
def strength(self):
|
||||
|
||||
@ -1810,3 +1810,119 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
num_frame_per_block=1):
|
||||
"""
|
||||
Autoregressive video sampler: block-by-block denoising with KV cache
|
||||
and flow-match re-noising for Causal Forcing / Self-Forcing models.
|
||||
|
||||
Requires a Causal-WAN compatible model (diffusion_model must expose
|
||||
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
|
||||
|
||||
All AR-loop parameters are passed via the SamplerARVideo node, not read
|
||||
from the checkpoint or transformer_options.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
model_options = extra_args.get("model_options", {})
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
|
||||
if x.ndim != 5:
|
||||
raise ValueError(
|
||||
f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
|
||||
"This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
|
||||
)
|
||||
|
||||
inner_model = model.inner_model.inner_model
|
||||
causal_model = inner_model.diffusion_model
|
||||
|
||||
if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
|
||||
raise TypeError(
|
||||
"ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
|
||||
"exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
|
||||
"does not support this interface — choose a different sampler."
|
||||
)
|
||||
|
||||
seed = extra_args.get("seed", 0)
|
||||
|
||||
bs, c, lat_t, lat_h, lat_w = x.shape
|
||||
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
|
||||
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
|
||||
device = x.device
|
||||
model_dtype = inner_model.get_dtype()
|
||||
|
||||
kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
|
||||
crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
|
||||
|
||||
output = torch.zeros_like(x)
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
current_start_frame = 0
|
||||
|
||||
# I2V: seed KV cache with the initial image latent before the denoising loop
|
||||
initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None)
|
||||
if initial_latent is not None:
|
||||
initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype)
|
||||
n_init = initial_latent.shape[2]
|
||||
output[:, :, :n_init] = initial_latent
|
||||
|
||||
ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches}
|
||||
transformer_options["ar_state"] = ar_state
|
||||
zero_sigma = sigmas.new_zeros([1])
|
||||
_ = model(initial_latent, zero_sigma * s_in, **extra_args)
|
||||
|
||||
current_start_frame = n_init
|
||||
remaining = lat_t - n_init
|
||||
num_blocks = -(-remaining // num_frame_per_block)
|
||||
|
||||
num_sigma_steps = len(sigmas) - 1
|
||||
total_real_steps = num_blocks * num_sigma_steps
|
||||
step_count = 0
|
||||
|
||||
try:
|
||||
for block_idx in trange(num_blocks, disable=disable):
|
||||
bf = min(num_frame_per_block, lat_t - current_start_frame)
|
||||
fs, fe = current_start_frame, current_start_frame + bf
|
||||
noisy_input = x[:, :, fs:fe]
|
||||
|
||||
ar_state = {
|
||||
"start_frame": current_start_frame,
|
||||
"kv_caches": kv_caches,
|
||||
"crossattn_caches": crossattn_caches,
|
||||
}
|
||||
transformer_options["ar_state"] = ar_state
|
||||
|
||||
for i in range(num_sigma_steps):
|
||||
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
|
||||
|
||||
if callback is not None:
|
||||
scaled_i = step_count * num_sigma_steps // total_real_steps
|
||||
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
|
||||
"sigma_hat": sigmas[i], "denoised": denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
noisy_input = denoised
|
||||
else:
|
||||
sigma_next = sigmas[i + 1]
|
||||
torch.manual_seed(seed + block_idx * 1000 + i)
|
||||
fresh_noise = torch.randn_like(denoised)
|
||||
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
||||
|
||||
for cache in kv_caches:
|
||||
cache["end"] -= bf * frame_seq_len
|
||||
|
||||
step_count += 1
|
||||
|
||||
output[:, :, fs:fe] = noisy_input
|
||||
|
||||
for cache in kv_caches:
|
||||
cache["end"] -= bf * frame_seq_len
|
||||
zero_sigma = sigmas.new_zeros([1])
|
||||
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
|
||||
|
||||
current_start_frame += bf
|
||||
finally:
|
||||
transformer_options.pop("ar_state", None)
|
||||
|
||||
return output
|
||||
|
||||
@ -9,6 +9,7 @@ class LatentFormat:
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
spacial_downscale_ratio = 8
|
||||
temporal_downscale_ratio = 1
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent * self.scale_factor
|
||||
@ -224,6 +225,7 @@ class Flux2(LatentFormat):
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
|
||||
self.taesd_decoder_name = "taef2_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
@ -234,6 +236,7 @@ class Flux2(LatentFormat):
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 6
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
@ -277,6 +280,7 @@ class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 32
|
||||
temporal_downscale_ratio = 8
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
@ -420,6 +424,7 @@ class LTXAV(LTXV):
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
scale_factor = 0.476986
|
||||
latent_rgb_factors = [
|
||||
[-0.0395, -0.0331, 0.0445],
|
||||
@ -446,6 +451,7 @@ class HunyuanVideo(LatentFormat):
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 8
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.1817, 0.2284, 0.2423],
|
||||
@ -471,6 +477,7 @@ class Cosmos1CV8x8x8(LatentFormat):
|
||||
class Wan21(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
|
||||
latent_rgb_factors = [
|
||||
[-0.1299, -0.1692, 0.2932],
|
||||
@ -733,6 +740,7 @@ class HunyuanVideo15(LatentFormat):
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
temporal_downscale_ratio = 4
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
@ -783,3 +791,29 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
No VAE encoding/decoding — the model operates directly on RGB pixels.
|
||||
"""
|
||||
pass
|
||||
|
||||
class CogVideoX(LatentFormat):
|
||||
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
|
||||
|
||||
scale_factor matches the vae/config.json scaling_factor for the 2b variant.
|
||||
The 5b-class checkpoints (CogVideoX-5b, CogVideoX-1.5-5B, CogVideoX-Fun-V1.5-*)
|
||||
use a different value; see CogVideoX1_5 below.
|
||||
"""
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.15258426
|
||||
|
||||
|
||||
class CogVideoX1_5(CogVideoX):
|
||||
"""Latent format for 5b-class CogVideoX checkpoints.
|
||||
|
||||
Covers THUDM/CogVideoX-5b, THUDM/CogVideoX-1.5-5B, and the CogVideoX-Fun
|
||||
V1.5-5b family (including VOID inpainting). All of these have
|
||||
scaling_factor=0.7 in their vae/config.json. Auto-selected in
|
||||
supported_models.CogVideoX_T2V based on transformer hidden dim.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.7
|
||||
|
||||
0
comfy/ldm/cogvideo/__init__.py
Normal file
0
comfy/ldm/cogvideo/__init__.py
Normal file
573
comfy/ldm/cogvideo/model.py
Normal file
573
comfy/ldm/cogvideo/model.py
Normal file
@ -0,0 +1,573 @@
|
||||
# CogVideoX 3D Transformer - ported to ComfyUI native ops
|
||||
# Architecture reference: diffusers CogVideoXTransformer3DModel
|
||||
# Style reference: comfy/ldm/wan/model.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0):
|
||||
"""Returns (cos, sin) each with shape [seq_len, dim].
|
||||
|
||||
Frequencies are computed at dim//2 resolution then repeat_interleaved
|
||||
to full dim, matching CogVideoX's interleaved (real, imag) pair format.
|
||||
"""
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim))
|
||||
angles = torch.outer(pos.float(), freqs.float())
|
||||
cos = angles.cos().repeat_interleave(2, dim=-1).float()
|
||||
sin = angles.sin().repeat_interleave(2, dim=-1).float()
|
||||
return (cos, sin)
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cos_sin):
|
||||
"""Apply CogVideoX rotary embedding to query or key tensor.
|
||||
|
||||
x: [B, heads, seq_len, head_dim]
|
||||
freqs_cos_sin: (cos, sin) each [seq_len, head_dim//2]
|
||||
|
||||
Uses interleaved pair rotation (same as diffusers CogVideoX/Flux).
|
||||
head_dim is reshaped to (-1, 2) pairs, rotated, then flattened back.
|
||||
"""
|
||||
cos, sin = freqs_cos_sin
|
||||
cos = cos[None, None, :, :].to(x.device)
|
||||
sin = sin[None, None, :, :].to(x.device)
|
||||
|
||||
# Interleaved pairs: [B, H, S, D] -> [B, H, S, D//2, 2] -> (real, imag)
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half)
|
||||
args = timesteps[:, None].float() * freqs[None] * scale
|
||||
embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
|
||||
if flip_sin_to_cos:
|
||||
embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None):
|
||||
if isinstance(spatial_size, int):
|
||||
spatial_size = (spatial_size, spatial_size)
|
||||
|
||||
grid_w = torch.arange(spatial_size[0], dtype=torch.float32, device=device) / spatial_interpolation_scale
|
||||
grid_h = torch.arange(spatial_size[1], dtype=torch.float32, device=device) / spatial_interpolation_scale
|
||||
grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale
|
||||
|
||||
grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij")
|
||||
|
||||
embed_dim_spatial = 2 * (embed_dim // 3)
|
||||
embed_dim_temporal = embed_dim // 3
|
||||
|
||||
pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device)
|
||||
pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device)
|
||||
|
||||
T, H, W = grid_t.shape
|
||||
pos_embed_temporal = pos_embed_temporal.unsqueeze(1).unsqueeze(1).expand(-1, H, W, -1)
|
||||
pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1)
|
||||
|
||||
return pos_embed
|
||||
|
||||
|
||||
def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None):
|
||||
T, H, W = grid_h.shape
|
||||
half_dim = embed_dim // 2
|
||||
pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(T, H, W, half_dim)
|
||||
pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(T, H, W, half_dim)
|
||||
return torch.cat([pos_h, pos_w], dim=-1)
|
||||
|
||||
|
||||
def _get_1d_sincos_pos_embed(embed_dim, pos, device=None):
|
||||
half = embed_dim // 2
|
||||
freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half)
|
||||
args = pos.float().reshape(-1)[:, None] * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if embed_dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
|
||||
class CogVideoXPatchEmbed(nn.Module):
|
||||
def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920,
|
||||
text_dim=4096, bias=True, sample_width=90, sample_height=60,
|
||||
sample_frames=49, temporal_compression_ratio=4,
|
||||
max_text_seq_length=226, spatial_interpolation_scale=1.875,
|
||||
temporal_interpolation_scale=1.0, use_positional_embeddings=True,
|
||||
use_learned_positional_embeddings=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.dim = dim
|
||||
self.sample_height = sample_height
|
||||
self.sample_width = sample_width
|
||||
self.sample_frames = sample_frames
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
self.max_text_seq_length = max_text_seq_length
|
||||
self.spatial_interpolation_scale = spatial_interpolation_scale
|
||||
self.temporal_interpolation_scale = temporal_interpolation_scale
|
||||
self.use_positional_embeddings = use_positional_embeddings
|
||||
self.use_learned_positional_embeddings = use_learned_positional_embeddings
|
||||
|
||||
if patch_size_t is None:
|
||||
self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype)
|
||||
else:
|
||||
self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype)
|
||||
|
||||
self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype)
|
||||
|
||||
if use_positional_embeddings or use_learned_positional_embeddings:
|
||||
persistent = use_learned_positional_embeddings
|
||||
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
|
||||
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
|
||||
|
||||
def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None):
|
||||
post_patch_height = sample_height // self.patch_size
|
||||
post_patch_width = sample_width // self.patch_size
|
||||
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
|
||||
if self.patch_size_t is not None:
|
||||
post_time_compression_frames = post_time_compression_frames // self.patch_size_t
|
||||
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
||||
|
||||
pos_embedding = get_3d_sincos_pos_embed(
|
||||
self.dim,
|
||||
(post_patch_width, post_patch_height),
|
||||
post_time_compression_frames,
|
||||
self.spatial_interpolation_scale,
|
||||
self.temporal_interpolation_scale,
|
||||
device=device,
|
||||
)
|
||||
pos_embedding = pos_embedding.reshape(-1, self.dim)
|
||||
joint_pos_embedding = pos_embedding.new_zeros(
|
||||
1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False
|
||||
)
|
||||
joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
|
||||
return joint_pos_embedding
|
||||
|
||||
def forward(self, text_embeds, image_embeds):
|
||||
input_dtype = text_embeds.dtype
|
||||
text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype)
|
||||
batch_size, num_frames, channels, height, width = image_embeds.shape
|
||||
|
||||
proj_dtype = self.proj.weight.dtype
|
||||
if self.patch_size_t is None:
|
||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
||||
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
|
||||
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
|
||||
image_embeds = image_embeds.flatten(3).transpose(2, 3)
|
||||
image_embeds = image_embeds.flatten(1, 2)
|
||||
else:
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
|
||||
image_embeds = image_embeds.reshape(
|
||||
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
|
||||
)
|
||||
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
|
||||
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
|
||||
|
||||
embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
|
||||
|
||||
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
|
||||
text_seq_length = text_embeds.shape[1]
|
||||
num_image_patches = image_embeds.shape[1]
|
||||
|
||||
if self.use_learned_positional_embeddings:
|
||||
image_pos = self.pos_embedding[
|
||||
:, self.max_text_seq_length:self.max_text_seq_length + num_image_patches
|
||||
].to(device=embeds.device, dtype=embeds.dtype)
|
||||
else:
|
||||
image_pos = get_3d_sincos_pos_embed(
|
||||
self.dim,
|
||||
(width // self.patch_size, height // self.patch_size),
|
||||
num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
|
||||
self.spatial_interpolation_scale,
|
||||
self.temporal_interpolation_scale,
|
||||
device=embeds.device,
|
||||
).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype)
|
||||
|
||||
# Build joint: zeros for text + sincos for image
|
||||
joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype)
|
||||
joint_pos[:, text_seq_length:] = image_pos
|
||||
embeds = embeds + joint_pos
|
||||
|
||||
return embeds
|
||||
|
||||
|
||||
class CogVideoXLayerNormZero(nn.Module):
|
||||
def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, temb):
|
||||
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
||||
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
|
||||
|
||||
|
||||
class CogVideoXAdaLayerNorm(nn.Module):
|
||||
def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, temb):
|
||||
temb = self.linear(self.silu(temb))
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
class CogVideoXBlock(nn.Module):
|
||||
def __init__(self, dim, num_heads, head_dim, time_dim,
|
||||
eps=1e-5, ff_inner_dim=None, ff_bias=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
# Self-attention (joint text + latent)
|
||||
self.q = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
self.k = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
self.v = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
|
||||
self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
|
||||
self.attn_out = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
# Feed-forward (GELU approximate)
|
||||
inner_dim = ff_inner_dim or dim * 4
|
||||
self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype)
|
||||
self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options=None):
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# Norm & modulate
|
||||
norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb)
|
||||
|
||||
# Joint self-attention
|
||||
qkv_input = torch.cat([norm_encoder, norm_hidden], dim=1)
|
||||
b, s, _ = qkv_input.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(qkv_input).view(b, s, n, d)
|
||||
k = self.k(qkv_input).view(b, s, n, d)
|
||||
v = self.v(qkv_input)
|
||||
|
||||
q = self.norm_q(q).view(b, s, n, d)
|
||||
k = self.norm_k(k).view(b, s, n, d)
|
||||
|
||||
# Apply rotary embeddings to image tokens only (diffusers format: [B, heads, seq, head_dim])
|
||||
if image_rotary_emb is not None:
|
||||
q_img = q[:, text_seq_length:].transpose(1, 2) # [B, heads, img_seq, head_dim]
|
||||
k_img = k[:, text_seq_length:].transpose(1, 2)
|
||||
q_img = apply_rotary_emb(q_img, image_rotary_emb)
|
||||
k_img = apply_rotary_emb(k_img, image_rotary_emb)
|
||||
q = torch.cat([q[:, :text_seq_length], q_img.transpose(1, 2)], dim=1)
|
||||
k = torch.cat([k[:, :text_seq_length], k_img.transpose(1, 2)], dim=1)
|
||||
|
||||
attn_out = optimized_attention(
|
||||
q.reshape(b, s, n * d),
|
||||
k.reshape(b, s, n * d),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
attn_out = self.attn_out(attn_out)
|
||||
|
||||
attn_encoder, attn_hidden = attn_out.split([text_seq_length, s - text_seq_length], dim=1)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder
|
||||
|
||||
# Norm & modulate for FF
|
||||
norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb)
|
||||
|
||||
# Feed-forward (GELU on concatenated text + latent)
|
||||
ff_input = torch.cat([norm_encoder, norm_hidden], dim=1)
|
||||
ff_output = self.ff_out(F.gelu(self.ff_proj(ff_input), approximate="tanh"))
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(nn.Module):
|
||||
def __init__(self,
|
||||
num_attention_heads=30,
|
||||
attention_head_dim=64,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
flip_sin_to_cos=True,
|
||||
freq_shift=0,
|
||||
time_embed_dim=512,
|
||||
ofs_embed_dim=None,
|
||||
text_embed_dim=4096,
|
||||
num_layers=30,
|
||||
dropout=0.0,
|
||||
attention_bias=True,
|
||||
sample_width=90,
|
||||
sample_height=60,
|
||||
sample_frames=49,
|
||||
patch_size=2,
|
||||
patch_size_t=None,
|
||||
temporal_compression_ratio=4,
|
||||
max_text_seq_length=226,
|
||||
spatial_interpolation_scale=1.875,
|
||||
temporal_interpolation_scale=1.0,
|
||||
use_rotary_positional_embeddings=False,
|
||||
use_learned_positional_embeddings=False,
|
||||
patch_bias=True,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
dim = num_attention_heads * attention_head_dim
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.max_text_seq_length = max_text_seq_length
|
||||
self.use_rotary_positional_embeddings = use_rotary_positional_embeddings
|
||||
|
||||
# 1. Patch embedding
|
||||
self.patch_embed = CogVideoXPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
in_channels=in_channels,
|
||||
dim=dim,
|
||||
text_dim=text_embed_dim,
|
||||
bias=patch_bias,
|
||||
sample_width=sample_width,
|
||||
sample_height=sample_height,
|
||||
sample_frames=sample_frames,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
max_text_seq_length=max_text_seq_length,
|
||||
spatial_interpolation_scale=spatial_interpolation_scale,
|
||||
temporal_interpolation_scale=temporal_interpolation_scale,
|
||||
use_positional_embeddings=not use_rotary_positional_embeddings,
|
||||
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
||||
device=device, dtype=torch.float32, operations=operations,
|
||||
)
|
||||
|
||||
# 2. Time embedding
|
||||
self.time_proj_dim = dim
|
||||
self.time_proj_flip = flip_sin_to_cos
|
||||
self.time_proj_shift = freq_shift
|
||||
self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype)
|
||||
self.time_embedding_act = nn.SiLU()
|
||||
self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype)
|
||||
|
||||
# Optional OFS embedding (CogVideoX 1.5 I2V)
|
||||
self.ofs_proj_dim = ofs_embed_dim
|
||||
if ofs_embed_dim:
|
||||
self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
|
||||
self.ofs_embedding_act = nn.SiLU()
|
||||
self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
|
||||
else:
|
||||
self.ofs_embedding_linear_1 = None
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
CogVideoXBlock(
|
||||
dim=dim,
|
||||
num_heads=num_attention_heads,
|
||||
head_dim=attention_head_dim,
|
||||
time_dim=time_embed_dim,
|
||||
eps=1e-5,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype)
|
||||
|
||||
# 4. Output
|
||||
self.norm_out = CogVideoXAdaLayerNorm(
|
||||
time_dim=time_embed_dim, dim=dim, eps=1e-5,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
if patch_size_t is None:
|
||||
output_dim = patch_size * patch_size * out_channels
|
||||
else:
|
||||
output_dim = patch_size * patch_size * patch_size_t * out_channels
|
||||
|
||||
self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype)
|
||||
|
||||
self.spatial_interpolation_scale = spatial_interpolation_scale
|
||||
self.temporal_interpolation_scale = temporal_interpolation_scale
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
|
||||
def forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs):
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, ofs, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs):
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
# ComfyUI passes [B, C, T, H, W]
|
||||
batch_size, channels, t, h, w = x.shape
|
||||
|
||||
# Pad to patch size (temporal + spatial), same pattern as WAN
|
||||
p_t = self.patch_size_t if self.patch_size_t is not None else 1
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t, self.patch_size, self.patch_size))
|
||||
|
||||
# CogVideoX expects [B, T, C, H, W]
|
||||
x = x.permute(0, 2, 1, 3, 4)
|
||||
batch_size, num_frames, channels, height, width = x.shape
|
||||
|
||||
# Time embedding
|
||||
t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift)
|
||||
t_emb = t_emb.to(dtype=x.dtype)
|
||||
emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb)))
|
||||
|
||||
if self.ofs_embedding_linear_1 is not None and ofs is not None:
|
||||
ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift)
|
||||
ofs_emb = ofs_emb.to(dtype=x.dtype)
|
||||
ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb)))
|
||||
emb = emb + ofs_emb
|
||||
|
||||
# Patch embedding
|
||||
hidden_states = self.patch_embed(context, x)
|
||||
|
||||
text_seq_length = context.shape[1]
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# Rotary embeddings (if used)
|
||||
image_rotary_emb = None
|
||||
if self.use_rotary_positional_embeddings:
|
||||
post_patch_height = height // self.patch_size
|
||||
post_patch_width = width // self.patch_size
|
||||
if self.patch_size_t is None:
|
||||
post_time = num_frames
|
||||
else:
|
||||
post_time = num_frames // self.patch_size_t
|
||||
image_rotary_emb = self._get_rotary_emb(post_patch_height, post_patch_width, post_time, device=x.device)
|
||||
|
||||
# Transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
|
||||
# Output projection
|
||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# Unpatchify
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
|
||||
if p_t is None:
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
else:
|
||||
output = hidden_states.reshape(
|
||||
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
|
||||
)
|
||||
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
|
||||
|
||||
# Back to ComfyUI format [B, C, T, H, W] and crop padding
|
||||
output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w]
|
||||
return output
|
||||
|
||||
def _get_rotary_emb(self, h, w, t, device):
|
||||
"""Compute CogVideoX 3D rotary positional embeddings.
|
||||
|
||||
For CogVideoX 1.5 (patch_size_t != None): uses "slice" mode — grid positions
|
||||
are integer arange computed at max_size, then sliced to actual size.
|
||||
For CogVideoX 1.0 (patch_size_t == None): uses "linspace" mode with crop coords
|
||||
scaled by spatial_interpolation_scale.
|
||||
"""
|
||||
d = self.attention_head_dim
|
||||
dim_t = d // 4
|
||||
dim_h = d // 8 * 3
|
||||
dim_w = d // 8 * 3
|
||||
|
||||
if self.patch_size_t is not None:
|
||||
# CogVideoX 1.5: "slice" mode — positions are simple integer indices
|
||||
# Compute at max(sample_size, actual_size) then slice to actual
|
||||
base_h = self.patch_embed.sample_height // self.patch_size
|
||||
base_w = self.patch_embed.sample_width // self.patch_size
|
||||
max_h = max(base_h, h)
|
||||
max_w = max(base_w, w)
|
||||
|
||||
grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
|
||||
grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
|
||||
grid_t = torch.arange(t, device=device, dtype=torch.float32)
|
||||
else:
|
||||
# CogVideoX 1.0: "linspace" mode with interpolation scale
|
||||
grid_h = torch.linspace(0, h - 1, h, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
|
||||
grid_w = torch.linspace(0, w - 1, w, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
|
||||
grid_t = torch.arange(t, device=device, dtype=torch.float32)
|
||||
|
||||
freqs_t = _get_1d_rotary_pos_embed(dim_t, grid_t)
|
||||
freqs_h = _get_1d_rotary_pos_embed(dim_h, grid_h)
|
||||
freqs_w = _get_1d_rotary_pos_embed(dim_w, grid_w)
|
||||
|
||||
t_cos, t_sin = freqs_t
|
||||
h_cos, h_sin = freqs_h
|
||||
w_cos, w_sin = freqs_w
|
||||
|
||||
# Slice to actual size (for "slice" mode where grids may be larger)
|
||||
t_cos, t_sin = t_cos[:t], t_sin[:t]
|
||||
h_cos, h_sin = h_cos[:h], h_sin[:h]
|
||||
w_cos, w_sin = w_cos[:w], w_sin[:w]
|
||||
|
||||
# Broadcast and concatenate into [T*H*W, head_dim]
|
||||
t_cos = t_cos[:, None, None, :].expand(-1, h, w, -1)
|
||||
t_sin = t_sin[:, None, None, :].expand(-1, h, w, -1)
|
||||
h_cos = h_cos[None, :, None, :].expand(t, -1, w, -1)
|
||||
h_sin = h_sin[None, :, None, :].expand(t, -1, w, -1)
|
||||
w_cos = w_cos[None, None, :, :].expand(t, h, -1, -1)
|
||||
w_sin = w_sin[None, None, :, :].expand(t, h, -1, -1)
|
||||
|
||||
cos = torch.cat([t_cos, h_cos, w_cos], dim=-1).reshape(t * h * w, -1)
|
||||
sin = torch.cat([t_sin, h_sin, w_sin], dim=-1).reshape(t * h * w, -1)
|
||||
return (cos, sin)
|
||||
566
comfy/ldm/cogvideo/vae.py
Normal file
566
comfy/ldm/cogvideo/vae.py
Normal file
@ -0,0 +1,566 @@
|
||||
# CogVideoX VAE - ported to ComfyUI native ops
|
||||
# Architecture reference: diffusers AutoencoderKLCogVideoX
|
||||
# Style reference: comfy/ldm/wan/vae.py
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
"""Causal 3D convolution with temporal padding.
|
||||
|
||||
Uses comfy.ops.Conv3d with autopad='causal_zero' fast path: when input has
|
||||
a single temporal frame and no cache, the 3D conv weight is sliced to act
|
||||
as a 2D conv, avoiding computation on zero-padded temporal dimensions.
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size,) * 3
|
||||
|
||||
time_kernel, height_kernel, width_kernel = kernel_size
|
||||
self.time_kernel_size = time_kernel
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
height_pad = (height_kernel - 1) // 2
|
||||
width_pad = (width_kernel - 1) // 2
|
||||
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_kernel - 1, 0)
|
||||
|
||||
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
|
||||
dilation = (dilation, 1, 1)
|
||||
self.conv = ops.Conv3d(
|
||||
in_channels, out_channels, kernel_size,
|
||||
stride=stride, dilation=dilation,
|
||||
padding=(0, height_pad, width_pad),
|
||||
)
|
||||
|
||||
def forward(self, x, conv_cache=None):
|
||||
if self.pad_mode == "replicate":
|
||||
x = F.pad(x, self.time_causal_padding, mode="replicate")
|
||||
conv_cache = None
|
||||
else:
|
||||
kernel_t = self.time_kernel_size
|
||||
if kernel_t > 1:
|
||||
if conv_cache is None and x.shape[2] == 1:
|
||||
# Fast path: single frame, no cache. All temporal padding
|
||||
# frames are copies of the input (replicate-style), so the
|
||||
# 3D conv reduces to a 2D conv with summed temporal kernel.
|
||||
w = comfy.ops.cast_to_input(self.conv.weight, x)
|
||||
b = comfy.ops.cast_to_input(self.conv.bias, x) if self.conv.bias is not None else None
|
||||
w2d = w.sum(dim=2, keepdim=True)
|
||||
out = F.conv3d(x, w2d, b,
|
||||
self.conv.stride, self.conv.padding,
|
||||
self.conv.dilation, self.conv.groups)
|
||||
return out, None
|
||||
cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1)
|
||||
x = torch.cat(cached + [x], dim=2)
|
||||
conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None
|
||||
|
||||
out = self.conv(x)
|
||||
return out, conv_cache
|
||||
|
||||
|
||||
def _interpolate_zq(zq, target_size):
|
||||
"""Interpolate latent z to target (T, H, W), matching CogVideoX's first-frame-special handling."""
|
||||
t = target_size[0]
|
||||
if t > 1 and t % 2 == 1:
|
||||
z_first = F.interpolate(zq[:, :, :1], size=(1, target_size[1], target_size[2]))
|
||||
z_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, target_size[1], target_size[2]))
|
||||
return torch.cat([z_first, z_rest], dim=2)
|
||||
return F.interpolate(zq, size=target_size)
|
||||
|
||||
|
||||
class SpatialNorm3D(nn.Module):
|
||||
"""Spatially conditioned normalization."""
|
||||
def __init__(self, f_channels, zq_channels, groups=32):
|
||||
super().__init__()
|
||||
self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(self, f, zq, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
if zq.shape[-3:] != f.shape[-3:]:
|
||||
zq = _interpolate_zq(zq, f.shape[-3:])
|
||||
|
||||
conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
|
||||
conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
|
||||
|
||||
return self.norm_layer(f) * conv_y + conv_b, new_cache
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
"""3D ResNet block with optional spatial norm."""
|
||||
def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32,
|
||||
eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"):
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.spatial_norm_dim = spatial_norm_dim
|
||||
|
||||
if act_fn == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif act_fn == "swish":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
else:
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
if spatial_norm_dim is None:
|
||||
self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||
else:
|
||||
self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
|
||||
self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
|
||||
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = ops.Linear(temb_channels, out_channels)
|
||||
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = ops.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
residual = x
|
||||
|
||||
if zq is not None:
|
||||
x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1"))
|
||||
else:
|
||||
x = self.norm1(x)
|
||||
|
||||
x = self.nonlinearity(x)
|
||||
x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1"))
|
||||
|
||||
if temb is not None and hasattr(self, "temb_proj"):
|
||||
x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
if zq is not None:
|
||||
x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2"))
|
||||
else:
|
||||
x = self.norm2(x)
|
||||
|
||||
x = self.nonlinearity(x)
|
||||
x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2"))
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
residual = self.conv_shortcut(residual)
|
||||
|
||||
return x + residual, new_cache
|
||||
|
||||
|
||||
class Downsample3D(nn.Module):
|
||||
"""3D downsampling with optional temporal compression."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
|
||||
if t % 2 == 1:
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
if x_rest.shape[-1] > 0:
|
||||
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
else:
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.conv(x)
|
||||
x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample3D(nn.Module):
|
||||
"""3D upsampling with optional temporal decompression."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
if x.shape[2] > 1 and x.shape[2] % 2 == 1:
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
x_first = F.interpolate(x_first, scale_factor=2.0)
|
||||
x_rest = F.interpolate(x_rest, scale_factor=2.0)
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
|
||||
elif x.shape[2] > 1:
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
else:
|
||||
x = x.squeeze(2)
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
x = x[:, :, None, :, :]
|
||||
else:
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.conv(x)
|
||||
x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
class DownBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, add_downsample=True,
|
||||
compress_time=False, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
if self.downsamplers is not None:
|
||||
for ds in self.downsamplers:
|
||||
x = ds(x)
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class MidBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels, out_channels=in_channels,
|
||||
temb_channels=temb_channels, groups=groups, eps=eps,
|
||||
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class UpBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16,
|
||||
add_upsample=True, compress_time=False, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels, groups=groups, eps=eps,
|
||||
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
if self.upsamplers is not None:
|
||||
for us in self.upsamplers:
|
||||
x = us(x)
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class Encoder3D(nn.Module):
|
||||
def __init__(self, in_channels=3, out_channels=16,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
layers_per_block=3, act_fn="silu",
|
||||
eps=1e-6, groups=32, pad_mode="first",
|
||||
temporal_compression_ratio=4):
|
||||
super().__init__()
|
||||
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
self.down_blocks = nn.ModuleList()
|
||||
output_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final = i == len(block_out_channels) - 1
|
||||
compress_time = i < temporal_compress_level
|
||||
|
||||
self.down_blocks.append(DownBlock3D(
|
||||
in_channels=input_channel, out_channels=output_channel,
|
||||
temb_channels=0, num_layers=layers_per_block,
|
||||
eps=eps, act_fn=act_fn, groups=groups,
|
||||
add_downsample=not is_final, compress_time=compress_time,
|
||||
))
|
||||
|
||||
self.mid_block = MidBlock3D(
|
||||
in_channels=block_out_channels[-1], temb_channels=0,
|
||||
num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, x, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
for i, block in enumerate(self.down_blocks):
|
||||
key = f"down_block_{i}"
|
||||
x, new_cache[key] = block(x, None, None, conv_cache.get(key))
|
||||
|
||||
x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block"))
|
||||
|
||||
x = self.norm_out(x)
|
||||
x = self.conv_act(x)
|
||||
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class Decoder3D(nn.Module):
|
||||
def __init__(self, in_channels=16, out_channels=3,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
layers_per_block=3, act_fn="silu",
|
||||
eps=1e-6, groups=32, pad_mode="first",
|
||||
temporal_compression_ratio=4):
|
||||
super().__init__()
|
||||
reversed_channels = list(reversed(block_out_channels))
|
||||
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
self.mid_block = MidBlock3D(
|
||||
in_channels=reversed_channels[0], temb_channels=0,
|
||||
num_layers=2, eps=eps, act_fn=act_fn, groups=groups,
|
||||
spatial_norm_dim=in_channels, pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList()
|
||||
output_channel = reversed_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
prev_channel = output_channel
|
||||
output_channel = reversed_channels[i]
|
||||
is_final = i == len(block_out_channels) - 1
|
||||
compress_time = i < temporal_compress_level
|
||||
|
||||
self.up_blocks.append(UpBlock3D(
|
||||
in_channels=prev_channel, out_channels=output_channel,
|
||||
temb_channels=0, num_layers=layers_per_block + 1,
|
||||
eps=eps, act_fn=act_fn, groups=groups,
|
||||
spatial_norm_dim=in_channels,
|
||||
add_upsample=not is_final, compress_time=compress_time,
|
||||
))
|
||||
|
||||
self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, sample, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block"))
|
||||
|
||||
for i, block in enumerate(self.up_blocks):
|
||||
key = f"up_block_{i}"
|
||||
x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key))
|
||||
|
||||
x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out"))
|
||||
x = self.conv_act(x)
|
||||
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return x, new_cache
|
||||
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoX(nn.Module):
|
||||
"""CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper.
|
||||
|
||||
Uses rolling temporal decode: conv_in + mid_block + temporal up_blocks run
|
||||
on the full (low-res) tensor, then the expensive spatial-only up_blocks +
|
||||
norm_out + conv_out are processed in small temporal chunks with conv_cache
|
||||
carrying causal state between chunks. This keeps peak VRAM proportional to
|
||||
chunk_size rather than total frame count.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3, out_channels=3,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
latent_channels=16, layers_per_block=3,
|
||||
act_fn="silu", eps=1e-6, groups=32,
|
||||
temporal_compression_ratio=4,
|
||||
):
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
|
||||
self.encoder = Encoder3D(
|
||||
in_channels=in_channels, out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
|
||||
act_fn=act_fn, eps=eps, groups=groups,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
)
|
||||
self.decoder = Decoder3D(
|
||||
in_channels=latent_channels, out_channels=out_channels,
|
||||
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
|
||||
act_fn=act_fn, eps=eps, groups=groups,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
)
|
||||
|
||||
self.num_latent_frames_batch_size = 2
|
||||
self.num_sample_frames_batch_size = 8
|
||||
|
||||
def encode(self, x):
|
||||
t = x.shape[2]
|
||||
frame_batch = self.num_sample_frames_batch_size
|
||||
remainder = t % frame_batch
|
||||
conv_cache = None
|
||||
enc = []
|
||||
|
||||
# Process remainder frames first so only the first chunk can have an
|
||||
# odd temporal dimension — where Downsample3D's first-frame-special
|
||||
# handling in temporal compression is actually correct.
|
||||
if remainder > 0:
|
||||
chunk, conv_cache = self.encoder(x[:, :, :remainder], conv_cache=conv_cache)
|
||||
enc.append(chunk.to(x.device))
|
||||
|
||||
for start in range(remainder, t, frame_batch):
|
||||
chunk, conv_cache = self.encoder(x[:, :, start:start + frame_batch], conv_cache=conv_cache)
|
||||
enc.append(chunk.to(x.device))
|
||||
|
||||
enc = torch.cat(enc, dim=2)
|
||||
mean, _ = enc.chunk(2, dim=1)
|
||||
return mean
|
||||
|
||||
def decode(self, z):
|
||||
return self._decode_rolling(z)
|
||||
|
||||
def _decode_batched(self, z):
|
||||
"""Original batched decode - processes 2 latent frames through full decoder."""
|
||||
t = z.shape[2]
|
||||
frame_batch = self.num_latent_frames_batch_size
|
||||
num_batches = max(t // frame_batch, 1)
|
||||
conv_cache = None
|
||||
dec = []
|
||||
for i in range(num_batches):
|
||||
remaining = t % frame_batch
|
||||
start = frame_batch * i + (0 if i == 0 else remaining)
|
||||
end = frame_batch * (i + 1) + remaining
|
||||
chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache)
|
||||
dec.append(chunk.cpu())
|
||||
return torch.cat(dec, dim=2).to(z.device)
|
||||
|
||||
def _decode_rolling(self, z):
|
||||
"""Rolling decode - processes low-res layers on full tensor, then rolls
|
||||
through expensive high-res layers in temporal chunks."""
|
||||
decoder = self.decoder
|
||||
device = z.device
|
||||
|
||||
# Determine which up_blocks have temporal upsample vs spatial-only.
|
||||
# Temporal up_blocks are cheap (low res), spatial-only are expensive.
|
||||
temporal_compress_level = int(np.log2(self.temporal_compression_ratio))
|
||||
split_at = temporal_compress_level # first N up_blocks do temporal upsample
|
||||
|
||||
# Phase 1: conv_in + mid_block + temporal up_blocks on full tensor (low/medium res)
|
||||
x, _ = decoder.conv_in(z)
|
||||
x, _ = decoder.mid_block(x, None, z)
|
||||
|
||||
for i in range(split_at):
|
||||
x, _ = decoder.up_blocks[i](x, None, z)
|
||||
|
||||
# Phase 2: remaining spatial-only up_blocks + norm_out + conv_out in temporal chunks
|
||||
remaining_blocks = list(range(split_at, len(decoder.up_blocks)))
|
||||
chunk_size = 4 # pixel frames per chunk through high-res layers
|
||||
t_expanded = x.shape[2]
|
||||
|
||||
if t_expanded <= chunk_size or len(remaining_blocks) == 0:
|
||||
# Small enough to process in one go
|
||||
for i in remaining_blocks:
|
||||
x, _ = decoder.up_blocks[i](x, None, z)
|
||||
x, _ = decoder.norm_out(x, z)
|
||||
x = decoder.conv_act(x)
|
||||
x, _ = decoder.conv_out(x)
|
||||
return x
|
||||
|
||||
# Expand z temporally once to match Phase 2's time dimension.
|
||||
# z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB
|
||||
# for the old approach of pre-interpolating to every pixel resolution).
|
||||
z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4]))
|
||||
|
||||
# Process in temporal chunks, interpolating spatially per-chunk to avoid
|
||||
# allocating full [B, C, t_expanded, H, W] tensors at each resolution.
|
||||
dec_out = []
|
||||
conv_caches = {}
|
||||
|
||||
for chunk_start in range(0, t_expanded, chunk_size):
|
||||
chunk_end = min(chunk_start + chunk_size, t_expanded)
|
||||
x_chunk = x[:, :, chunk_start:chunk_end]
|
||||
z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end]
|
||||
z_spatial_cache = {}
|
||||
|
||||
for i in remaining_blocks:
|
||||
block = decoder.up_blocks[i]
|
||||
cache_key = f"up_block_{i}"
|
||||
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
|
||||
if hw_key not in z_spatial_cache:
|
||||
if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]:
|
||||
z_spatial_cache[hw_key] = z_t_chunk
|
||||
else:
|
||||
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
|
||||
x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key))
|
||||
conv_caches[cache_key] = new_cache
|
||||
|
||||
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
|
||||
if hw_key not in z_spatial_cache:
|
||||
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
|
||||
x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out"))
|
||||
conv_caches["norm_out"] = new_cache
|
||||
x_chunk = decoder.conv_act(x_chunk)
|
||||
x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
|
||||
conv_caches["conv_out"] = new_cache
|
||||
|
||||
dec_out.append(x_chunk.cpu())
|
||||
del z_spatial_cache
|
||||
|
||||
del x, z_time_expanded
|
||||
return torch.cat(dec_out, dim=2).to(device)
|
||||
@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_prefetch
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_prompt_timestep=a_prompt_timestep,
|
||||
)
|
||||
|
||||
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
|
||||
@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
|
||||
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
|
||||
n_rep = q.shape[-3] // k.shape[-3]
|
||||
k = k.repeat_interleave(n_rep, dim=-3)
|
||||
v = v.repeat_interleave(n_rep, dim=-3)
|
||||
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
h = heads
|
||||
if skip_reshape:
|
||||
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
b, _, dim_head = query.shape
|
||||
dim_head //= heads
|
||||
|
||||
if "scale" in kwargs:
|
||||
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
|
||||
query = query * (kwargs["scale"] * dim_head ** 0.5)
|
||||
|
||||
if skip_reshape:
|
||||
query = query.reshape(b * heads, -1, dim_head)
|
||||
value = value.reshape(b * heads, -1, dim_head)
|
||||
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
|
||||
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
|
||||
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
|
||||
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
k[i : i + SDP_BATCH_LIMIT],
|
||||
v[i : i + SDP_BATCH_LIMIT],
|
||||
attn_mask=m,
|
||||
dropout_p=0.0, is_causal=False
|
||||
dropout_p=0.0, is_causal=False, **sdpa_extra
|
||||
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
# according to the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
if verbose:
|
||||
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
|
||||
@ -561,7 +561,8 @@ class SAM3Model(nn.Module):
|
||||
return high_res_masks
|
||||
|
||||
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1):
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
||||
target_device=None, target_dtype=None):
|
||||
"""Track video with optional per-frame text-prompted detection."""
|
||||
bb = self.detector.backbone["vision_backbone"]
|
||||
|
||||
@ -589,8 +590,10 @@ class SAM3Model(nn.Module):
|
||||
return self.tracker.track_video_with_detection(
|
||||
backbone_fn, images, initial_masks, detect_fn,
|
||||
new_det_thresh=new_det_thresh, max_objects=max_objects,
|
||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
|
||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
# SAM3 (non-multiplex) — no detection support, requires initial masks
|
||||
if initial_masks is None:
|
||||
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
|
||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
|
||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
|
||||
@ -200,8 +200,13 @@ def pack_masks(masks):
|
||||
|
||||
def unpack_masks(packed):
|
||||
"""Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
|
||||
shifts = torch.arange(8, device=packed.device)
|
||||
return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool()
|
||||
bits = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
|
||||
return (packed.unsqueeze(-1) & bits).bool().view(*packed.shape[:-1], -1)
|
||||
|
||||
|
||||
def _prep_frame(images, idx, device, dt, size):
|
||||
"""Slice CPU full-res frames, transfer to GPU in target dtype, and resize to (size, size)."""
|
||||
return comfy.utils.common_upscale(images[idx].to(device=device, dtype=dt), size, size, "bicubic", crop="disabled")
|
||||
|
||||
|
||||
def _compute_backbone(backbone_fn, frame, frame_idx=None):
|
||||
@ -1078,16 +1083,19 @@ class SAM3Tracker(nn.Module):
|
||||
# SAM3: drop last FPN level
|
||||
return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]
|
||||
|
||||
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None):
|
||||
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None,
|
||||
target_device=None, target_dtype=None):
|
||||
"""Track one object, computing backbone per frame to save VRAM."""
|
||||
N = images.shape[0]
|
||||
device, dt = images.device, images.dtype
|
||||
device = target_device if target_device is not None else images.device
|
||||
dt = target_dtype if target_dtype is not None else images.dtype
|
||||
size = self.image_size
|
||||
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
||||
all_masks = []
|
||||
|
||||
for frame_idx in tqdm(range(N), desc="tracking"):
|
||||
vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
|
||||
backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx)
|
||||
backbone_fn, _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), frame_idx=frame_idx)
|
||||
mask_input = None
|
||||
if frame_idx == 0:
|
||||
mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
|
||||
@ -1114,12 +1122,13 @@ class SAM3Tracker(nn.Module):
|
||||
|
||||
return torch.cat(all_masks, dim=0) # [N, 1, H, W]
|
||||
|
||||
def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs):
|
||||
def track_video(self, backbone_fn, images, initial_masks, pbar=None,
|
||||
target_device=None, target_dtype=None, **kwargs):
|
||||
"""Track one or more objects across video frames.
|
||||
|
||||
Args:
|
||||
backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
|
||||
images: [N, 3, 1008, 1008] video frames
|
||||
images: [N, 3, H, W] CPU full-res video frames (resized per-frame to self.image_size)
|
||||
initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
|
||||
pbar: optional progress bar
|
||||
|
||||
@ -1130,7 +1139,8 @@ class SAM3Tracker(nn.Module):
|
||||
per_object = []
|
||||
for obj_idx in range(N_obj):
|
||||
obj_masks = self._track_single_object(
|
||||
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar)
|
||||
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
per_object.append(obj_masks)
|
||||
|
||||
return torch.cat(per_object, dim=1) # [N, N_obj, H, W]
|
||||
@ -1632,11 +1642,18 @@ class SAM31Tracker(nn.Module):
|
||||
return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
|
||||
return []
|
||||
|
||||
INTERNAL_MAX_OBJECTS = 64 # Hard ceiling on accumulated tracks; max_objects=0 or any value above this is clamped here.
|
||||
|
||||
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
||||
backbone_obj=None, pbar=None):
|
||||
backbone_obj=None, pbar=None, target_device=None, target_dtype=None):
|
||||
"""Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits."""
|
||||
N, device, dt = images.shape[0], images.device, images.dtype
|
||||
if max_objects <= 0 or max_objects > self.INTERNAL_MAX_OBJECTS:
|
||||
max_objects = self.INTERNAL_MAX_OBJECTS
|
||||
N = images.shape[0]
|
||||
device = target_device if target_device is not None else images.device
|
||||
dt = target_dtype if target_dtype is not None else images.dtype
|
||||
size = self.image_size
|
||||
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
||||
all_masks = []
|
||||
idev = comfy.model_management.intermediate_device()
|
||||
@ -1656,7 +1673,7 @@ class SAM31Tracker(nn.Module):
|
||||
prefetch = True
|
||||
except RuntimeError:
|
||||
pass
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(0, 1), device, dt, size), frame_idx=0)
|
||||
|
||||
for frame_idx in tqdm(range(N), desc="tracking"):
|
||||
vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb
|
||||
@ -1666,7 +1683,7 @@ class SAM31Tracker(nn.Module):
|
||||
backbone_stream.wait_stream(torch.cuda.current_stream(device))
|
||||
with torch.cuda.stream(backbone_stream):
|
||||
next_bb = self._compute_backbone_frame(
|
||||
backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
|
||||
# Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
|
||||
det_masks = torch.empty(0, device=device)
|
||||
@ -1687,7 +1704,7 @@ class SAM31Tracker(nn.Module):
|
||||
current_out = self._condition_with_masks(
|
||||
initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
|
||||
feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
|
||||
images[frame_idx:frame_idx + 1], trunk_out)
|
||||
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out)
|
||||
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
||||
obj_scores = [1.0] * mux_state.total_valid_entries
|
||||
if keep_alive is not None:
|
||||
@ -1702,7 +1719,7 @@ class SAM31Tracker(nn.Module):
|
||||
current_out = self._condition_with_masks(
|
||||
det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
|
||||
output_dict, N, mux_state, backbone_obj,
|
||||
images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0)
|
||||
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out, threshold=0.0)
|
||||
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
||||
obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
|
||||
if keep_alive is not None:
|
||||
@ -1718,7 +1735,7 @@ class SAM31Tracker(nn.Module):
|
||||
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
||||
cur_bb = next_bb
|
||||
else:
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
continue
|
||||
else:
|
||||
N_obj = mux_state.total_valid_entries
|
||||
@ -1768,7 +1785,7 @@ class SAM31Tracker(nn.Module):
|
||||
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
||||
cur_bb = next_bb
|
||||
else:
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
|
||||
if not all_masks or all(m is None for m in all_masks):
|
||||
return {"packed_masks": None, "n_frames": N, "scores": []}
|
||||
|
||||
276
comfy/ldm/wan/ar_model.py
Normal file
276
comfy/ldm/wan/ar_model.py
Normal file
@ -0,0 +1,276 @@
|
||||
"""
|
||||
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
|
||||
autoregressive (frame-by-frame) video generation via Causal Forcing.
|
||||
|
||||
Weight-compatible with the standard WanModel -- same layer names, same shapes.
|
||||
The difference is purely in the forward pass: this model processes one temporal
|
||||
block at a time and maintains a KV cache across blocks.
|
||||
|
||||
Reference: https://github.com/thu-ml/Causal-Forcing
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.wan.model import (
|
||||
sinusoidal_embedding_1d,
|
||||
repeat_e,
|
||||
WanModel,
|
||||
WanAttentionBlock,
|
||||
)
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class CausalWanSelfAttention(nn.Module):
|
||||
"""Self-attention with KV cache support for autoregressive inference."""
|
||||
|
||||
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
|
||||
eps=1e-6, operation_settings={}):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
ops = operation_settings.get("operations")
|
||||
device = operation_settings.get("device")
|
||||
dtype = operation_settings.get("dtype")
|
||||
|
||||
self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, freqs, kv_cache=None, transformer_options={}):
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
|
||||
k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
|
||||
v = self.v(x).view(b, s, n, d)
|
||||
|
||||
if kv_cache is None:
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
k.view(b, s, n * d),
|
||||
v.view(b, s, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
end = kv_cache["end"]
|
||||
new_end = end + s
|
||||
|
||||
# Roped K and plain V go into cache
|
||||
kv_cache["k"][:, end:new_end] = k
|
||||
kv_cache["v"][:, end:new_end] = v
|
||||
kv_cache["end"] = new_end
|
||||
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
kv_cache["k"][:, :new_end].view(b, new_end, n * d),
|
||||
kv_cache["v"][:, :new_end].view(b, new_end, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanAttentionBlock(WanAttentionBlock):
|
||||
"""Transformer block with KV-cached self-attention and cross-attention caching."""
|
||||
|
||||
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
self.self_attn = CausalWanSelfAttention(
|
||||
dim, num_heads, window_size, qk_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, e, freqs, context, context_img_len=257,
|
||||
kv_cache=None, crossattn_cache=None, transformer_options={}):
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
|
||||
|
||||
# Self-attention with optional KV cache
|
||||
x = x.contiguous()
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, kv_cache=kv_cache, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
del y
|
||||
|
||||
# Cross-attention with optional caching
|
||||
if crossattn_cache is not None and crossattn_cache.get("is_init"):
|
||||
q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
|
||||
x_ca = optimized_attention(
|
||||
q, crossattn_cache["k"], crossattn_cache["v"],
|
||||
heads=self.num_heads, transformer_options=transformer_options)
|
||||
x = x + self.cross_attn.o(x_ca)
|
||||
else:
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
if crossattn_cache is not None:
|
||||
crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
|
||||
crossattn_cache["v"] = self.cross_attn.v(context)
|
||||
crossattn_cache["is_init"] = True
|
||||
|
||||
# FFN
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanModel(WanModel):
|
||||
"""
|
||||
Wan 2.1 diffusion backbone with causal KV-cache support.
|
||||
|
||||
Same weight structure as WanModel -- loads identical state dicts.
|
||||
Adds forward_block() for frame-by-frame autoregressive inference.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='t2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None):
|
||||
super().__init__(
|
||||
model_type=model_type, patch_size=patch_size, text_len=text_len,
|
||||
in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
|
||||
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
|
||||
num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
|
||||
cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
|
||||
wan_attn_block_class=CausalWanAttentionBlock,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward_block(self, x, timestep, context, start_frame,
|
||||
kv_caches, crossattn_caches, clip_fea=None):
|
||||
"""
|
||||
Forward one temporal block for autoregressive inference.
|
||||
|
||||
Args:
|
||||
x: [B, C, block_frames, H, W] input latent for the current block
|
||||
timestep: [B, block_frames] per-frame timesteps
|
||||
context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
|
||||
start_frame: temporal frame index for RoPE offset
|
||||
kv_caches: list of per-layer KV cache dicts
|
||||
crossattn_caches: list of per-layer cross-attention cache dicts
|
||||
clip_fea: optional CLIP features for I2V
|
||||
|
||||
Returns:
|
||||
flow_pred: [B, C_out, block_frames, H, W] flow prediction
|
||||
"""
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
bs, c, t, h, w = x.shape
|
||||
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# Per-frame time embedding
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
|
||||
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
# Text embedding (reuses crossattn_cache after first block)
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None and self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea)
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
# RoPE for current block's temporal position
|
||||
freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
|
||||
|
||||
# Transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
x = block(x, e=e0, freqs=freqs, context=context,
|
||||
context_img_len=context_img_len,
|
||||
kv_cache=kv_caches[i],
|
||||
crossattn_cache=crossattn_caches[i])
|
||||
|
||||
# Head
|
||||
x = self.head(x, e)
|
||||
|
||||
# Unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x[:, :, :t, :h, :w]
|
||||
|
||||
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
|
||||
"""Create fresh KV caches for all layers."""
|
||||
caches = []
|
||||
for _ in range(self.num_layers):
|
||||
caches.append({
|
||||
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"end": 0,
|
||||
})
|
||||
return caches
|
||||
|
||||
def init_crossattn_caches(self, batch_size, device, dtype):
|
||||
"""Create fresh cross-attention caches for all layers."""
|
||||
caches = []
|
||||
for _ in range(self.num_layers):
|
||||
caches.append({"is_init": False})
|
||||
return caches
|
||||
|
||||
def reset_kv_caches(self, kv_caches):
|
||||
"""Reset KV caches to empty (reuse allocated memory)."""
|
||||
for cache in kv_caches:
|
||||
cache["end"] = 0
|
||||
|
||||
def reset_crossattn_caches(self, crossattn_caches):
|
||||
"""Reset cross-attention caches."""
|
||||
for cache in crossattn_caches:
|
||||
cache["is_init"] = False
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.dim // self.num_heads
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
ar_state = transformer_options.get("ar_state")
|
||||
if ar_state is not None:
|
||||
bs = x.shape[0]
|
||||
block_frames = x.shape[2]
|
||||
t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
|
||||
return self.forward_block(
|
||||
x=x, timestep=t_per_frame, context=context,
|
||||
start_frame=ar_state["start_frame"],
|
||||
kv_caches=ar_state["kv_caches"],
|
||||
crossattn_caches=ar_state["crossattn_caches"],
|
||||
clip_fea=clip_fea,
|
||||
)
|
||||
|
||||
return super().forward(x, timestep, context, clip_fea=clip_fea,
|
||||
time_dim_concat=time_dim_concat,
|
||||
transformer_options=transformer_options, **kwargs)
|
||||
@ -17,6 +17,7 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_base
|
||||
@ -342,6 +343,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format
|
||||
|
||||
if isinstance(model, comfy.model_base.ErnieImage):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
@ -467,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
def prefetch_prepared_value(value, allocate_buffer, stream):
|
||||
if isinstance(value, torch.Tensor):
|
||||
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
|
||||
elif isinstance(value, weight_adapter.WeightAdapterBase):
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
|
||||
elif isinstance(value, list):
|
||||
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
@ -52,6 +53,7 @@ import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.cogvideo.model
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
import comfy.ldm.ernie.model
|
||||
import comfy.ldm.sam3.detector
|
||||
@ -81,6 +83,7 @@ class ModelType(Enum):
|
||||
IMG_TO_IMG = 9
|
||||
FLOW_COSMOS = 10
|
||||
IMG_TO_IMG_FLOW = 11
|
||||
V_PREDICTION_DDPM = 12
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@ -115,6 +118,8 @@ def model_sampling(model_config, model_type):
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||
elif model_type == ModelType.V_PREDICTION_DDPM:
|
||||
c = comfy.model_sampling.V_PREDICTION_DDPM
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -210,6 +215,11 @@ class BaseModel(torch.nn.Module):
|
||||
if "latent_shapes" in extra_conds:
|
||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["prefetch_dynamic_vbars"] = (
|
||||
self.current_patcher is not None and self.current_patcher.is_dynamic()
|
||||
)
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||
model_output, _ = utils.pack_latents(model_output)
|
||||
@ -1356,6 +1366,13 @@ class WAN21(BaseModel):
|
||||
return out
|
||||
|
||||
|
||||
class WAN21_CausalAR(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
|
||||
self.image_to_video = False
|
||||
|
||||
|
||||
class WAN21_Vace(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
|
||||
@ -1979,3 +1996,59 @@ class ErnieImage(BaseModel):
|
||||
class SAM3(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)
|
||||
|
||||
class CogVideoX(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
# Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent)
|
||||
extra_channels = self.diffusion_model.in_channels - noise.shape[1]
|
||||
if extra_channels == 0:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
shape = list(noise.shape)
|
||||
shape[1] = extra_channels
|
||||
return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
|
||||
latent_dim = self.latent_format.latent_channels
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
if noise.ndim == 5 and image.ndim == 5:
|
||||
if image.shape[-3] < noise.shape[-3]:
|
||||
image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0)
|
||||
elif image.shape[-3] > noise.shape[-3]:
|
||||
image = image[:, :, :noise.shape[-3]]
|
||||
|
||||
for i in range(0, image.shape[1], latent_dim):
|
||||
image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
if image.shape[1] > extra_channels:
|
||||
image = image[:, :extra_channels]
|
||||
elif image.shape[1] < extra_channels:
|
||||
repeats = extra_channels // image.shape[1]
|
||||
remainder = extra_channels % image.shape[1]
|
||||
parts = [image] * repeats
|
||||
if remainder > 0:
|
||||
parts.append(image[:, :remainder])
|
||||
image = torch.cat(parts, dim=1)
|
||||
|
||||
return image
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
# OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR
|
||||
if self.diffusion_model.ofs_proj_dim is not None:
|
||||
ofs = kwargs.get("ofs", None)
|
||||
if ofs is None:
|
||||
noise = kwargs.get("noise", None)
|
||||
ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
|
||||
out['ofs'] = comfy.conds.CONDRegular(ofs)
|
||||
return out
|
||||
|
||||
@ -490,6 +490,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cogvideox"
|
||||
|
||||
# Extract config from weight shapes
|
||||
norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)]
|
||||
time_embed_dim = norm1_weight.shape[1]
|
||||
dim = norm1_weight.shape[0] // 6
|
||||
|
||||
dit_config["num_attention_heads"] = dim // 64
|
||||
dit_config["attention_head_dim"] = 64
|
||||
dit_config["time_embed_dim"] = time_embed_dim
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
|
||||
# Detect in_channels from patch_embed
|
||||
patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix)
|
||||
if patch_proj_key in state_dict_keys:
|
||||
w = state_dict[patch_proj_key]
|
||||
if w.ndim == 4:
|
||||
# Conv2d: [out, in, kh, kw] — CogVideoX 1.0
|
||||
dit_config["in_channels"] = w.shape[1]
|
||||
dit_config["patch_size"] = w.shape[2]
|
||||
elif w.ndim == 2:
|
||||
# Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["patch_size_t"] = 2
|
||||
dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32
|
||||
|
||||
text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix)
|
||||
if text_proj_key in state_dict_keys:
|
||||
dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1]
|
||||
|
||||
# Detect OFS embedding
|
||||
ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix)
|
||||
if ofs_key in state_dict_keys:
|
||||
dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1]
|
||||
|
||||
# Detect positional embedding type
|
||||
pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
|
||||
if pos_key in state_dict_keys:
|
||||
dit_config["use_learned_positional_embeddings"] = True
|
||||
dit_config["use_rotary_positional_embeddings"] = False
|
||||
else:
|
||||
dit_config["use_learned_positional_embeddings"] = False
|
||||
dit_config["use_rotary_positional_embeddings"] = True
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "wan2.1"
|
||||
|
||||
@ -31,6 +31,7 @@ from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
import comfy_aimdo.vram_buffer
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -112,10 +113,6 @@ if args.directml is not None:
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
_ = torch.xpu.device_count()
|
||||
@ -583,9 +580,6 @@ class LoadedModel:
|
||||
|
||||
real_model = self.model.model
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||
with torch.no_grad():
|
||||
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
|
||||
self.real_model = weakref.ref(real_model)
|
||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||
@ -663,6 +657,7 @@ def minimum_inference_memory():
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||
cleanup_models_gc()
|
||||
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@ -726,13 +721,15 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||
|
||||
models_temp = set()
|
||||
# Order-preserving dedup. A plain set() would randomize iteration order across runs
|
||||
models_temp = {}
|
||||
for m in models:
|
||||
models_temp.add(m)
|
||||
models_temp[m] = None
|
||||
for mm in m.model_patches_models():
|
||||
models_temp.add(mm)
|
||||
models_temp[mm] = None
|
||||
|
||||
models = models_temp
|
||||
models = list(models_temp)
|
||||
models.reverse()
|
||||
|
||||
models_to_load = []
|
||||
|
||||
@ -1181,6 +1178,10 @@ stream_counters = {}
|
||||
|
||||
STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
|
||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||
|
||||
def get_cast_buffer(offload_stream, device, size, ref):
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
@ -1214,13 +1215,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
||||
|
||||
return cast_buffer
|
||||
|
||||
def get_aimdo_cast_buffer(offload_stream, device):
|
||||
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
|
||||
if cast_buffer is None:
|
||||
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
|
||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
|
||||
return cast_buffer
|
||||
def reset_cast_buffers():
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||
if offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
@ -1580,10 +1594,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.get_device_properties(device).has_fp16
|
||||
return torch.xpu.get_device_properties(device).has_fp16
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
@ -1649,10 +1660,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.is_bf16_supported()
|
||||
return torch.xpu.is_bf16_supported()
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
@ -1783,6 +1791,7 @@ def soft_empty_cache(force=False):
|
||||
if cpu_state == CPUState.MPS:
|
||||
torch.mps.empty_cache()
|
||||
elif is_intel_xpu():
|
||||
torch.xpu.synchronize()
|
||||
torch.xpu.empty_cache()
|
||||
elif is_ascend_npu():
|
||||
torch.npu.empty_cache()
|
||||
|
||||
@ -26,11 +26,13 @@ import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
import comfy.float
|
||||
import comfy.hooks
|
||||
import comfy.lora
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
@ -120,9 +122,20 @@ class LowVramPatch:
|
||||
self.patches = patches
|
||||
self.convert_func = convert_func # TODO: remove
|
||||
self.set_func = set_func
|
||||
self.prepared_patches = None
|
||||
|
||||
def prepare(self, allocate_buffer, stream):
|
||||
self.prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
|
||||
for patch in self.patches[self.key]
|
||||
]
|
||||
|
||||
def clear_prepared(self):
|
||||
self.prepared_patches = None
|
||||
|
||||
def __call__(self, weight):
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
|
||||
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
|
||||
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
||||
|
||||
@ -856,7 +869,9 @@ class ModelPatcher:
|
||||
if m.comfy_patched_weights == True:
|
||||
continue
|
||||
|
||||
for param in params:
|
||||
for param, param_value in params.items():
|
||||
if hasattr(m, "comfy_cast_weights") and getattr(param_value, "is_meta", False):
|
||||
comfy.ops.disable_weight_init._zero_init_parameter(m, param)
|
||||
key = key_param_name_to_key(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
@ -1637,7 +1652,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
|
||||
|
||||
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
log_key = (self.patches_uuid, allocated_size, num_patches, len(self.backup), self.model.model_loaded_weight_memory)
|
||||
in_loop = bool(getattr(tqdm.tqdm, "_instances", None))
|
||||
level = logging.DEBUG if in_loop and getattr(self, "_last_prepare_log_key", None) == log_key else logging.INFO
|
||||
self._last_prepare_log_key = log_key
|
||||
logging.log(level, f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
|
||||
|
||||
self.model.device = device_to
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
66
comfy/model_prefetch.py
Normal file
66
comfy/model_prefetch.py
Normal file
@ -0,0 +1,66 @@
|
||||
import comfy_aimdo.model_vbar
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
|
||||
PREFETCH_QUEUES = []
|
||||
|
||||
def cleanup_prefetched_modules(comfy_modules):
|
||||
for s in comfy_modules:
|
||||
prefetch = getattr(s, "_prefetch", None)
|
||||
if prefetch is None:
|
||||
continue
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_fn.clear_prepared()
|
||||
if prefetch["signature"] is not None:
|
||||
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||
delattr(s, "_prefetch")
|
||||
|
||||
def cleanup_prefetch_queues():
|
||||
global PREFETCH_QUEUES
|
||||
|
||||
for queue in PREFETCH_QUEUES:
|
||||
for entry in queue:
|
||||
if entry is None or not isinstance(entry, tuple):
|
||||
continue
|
||||
_, prefetch_state = entry
|
||||
comfy_modules = prefetch_state[1]
|
||||
if comfy_modules is not None:
|
||||
cleanup_prefetched_modules(comfy_modules)
|
||||
PREFETCH_QUEUES = []
|
||||
|
||||
def prefetch_queue_pop(queue, device, module):
|
||||
if queue is None:
|
||||
return
|
||||
|
||||
consumed = queue.pop(0)
|
||||
if consumed is not None:
|
||||
offload_stream, prefetch_state = consumed
|
||||
if offload_stream is not None:
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||
_, comfy_modules = prefetch_state
|
||||
if comfy_modules is not None:
|
||||
cleanup_prefetched_modules(comfy_modules)
|
||||
|
||||
prefetch = queue[0]
|
||||
if prefetch is not None:
|
||||
comfy_modules = []
|
||||
for s in prefetch.modules():
|
||||
if hasattr(s, "_v"):
|
||||
comfy_modules.append(s)
|
||||
|
||||
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
queue[0] = (offload_stream, (prefetch, comfy_modules))
|
||||
|
||||
def make_prefetch_queue(queue, device, transformer_options):
|
||||
if (not transformer_options.get("prefetch_dynamic_vbars", False)
|
||||
or comfy.model_management.NUM_STREAMS == 0
|
||||
or comfy.model_management.is_device_cpu(device)
|
||||
or not comfy.model_management.device_supports_non_blocking(device)):
|
||||
return None
|
||||
|
||||
queue = [None] + queue + [None]
|
||||
PREFETCH_QUEUES.append(queue)
|
||||
return queue
|
||||
@ -54,6 +54,30 @@ class V_PREDICTION(EPS):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
class V_PREDICTION_DDPM:
|
||||
"""CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v.
|
||||
x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v
|
||||
= x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1)
|
||||
"""
|
||||
def calculate_input(self, sigma, noise):
|
||||
return noise
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
|
||||
else:
|
||||
noise = noise * sigma
|
||||
noise += latent_image
|
||||
return noise
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
return latent
|
||||
|
||||
class EDM(V_PREDICTION):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
|
||||
301
comfy/ops.py
301
comfy/ops.py
@ -79,37 +79,68 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
def materialize_meta_param(s, param_keys):
|
||||
for param_key in param_keys:
|
||||
param = getattr(s, param_key, None)
|
||||
if param is not None and getattr(param, "is_meta", False):
|
||||
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
||||
return weight, bias, (None, None, None)
|
||||
|
||||
# FIXME: add n=1 cache hit fast path
|
||||
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
cast_buffer = None
|
||||
cast_buffer_offset = 0
|
||||
|
||||
def ensure_offload_stream(module, required_size, check_largest):
|
||||
nonlocal offload_stream
|
||||
nonlocal cast_buffer
|
||||
|
||||
if offload_stream is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
|
||||
return
|
||||
|
||||
current_size = 0 if cast_buffer is None else cast_buffer.size()
|
||||
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
cast_buffer = None
|
||||
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
|
||||
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
|
||||
|
||||
def get_cast_buffer(buffer_size):
|
||||
nonlocal offload_stream
|
||||
nonlocal cast_buffer
|
||||
nonlocal cast_buffer_offset
|
||||
|
||||
if buffer_size == 0:
|
||||
return None
|
||||
|
||||
if offload_stream is None:
|
||||
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
|
||||
|
||||
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
|
||||
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
|
||||
cast_buffer_offset += buffer_size
|
||||
return buffer
|
||||
|
||||
for s in comfy_modules:
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
prefetch = {
|
||||
"signature": signature,
|
||||
"resident": resident,
|
||||
}
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
if signature is not None:
|
||||
if resident:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
s._prefetch = prefetch
|
||||
continue
|
||||
|
||||
if not resident:
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
needs_cast = False
|
||||
|
||||
xfer_source = [ s.weight, s.bias ]
|
||||
|
||||
@ -121,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
if data is None:
|
||||
continue
|
||||
if data.dtype != geometry.dtype:
|
||||
needs_cast = True
|
||||
cast_dest = xfer_dest
|
||||
if cast_dest is None:
|
||||
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
|
||||
xfer_dest = None
|
||||
break
|
||||
|
||||
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if xfer_dest is None and offload_stream is not None:
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
if xfer_dest is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
|
||||
if xfer_dest is None:
|
||||
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
||||
offload_stream = None
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
if signature is None and pin is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
@ -149,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
if cast_dest is not None:
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
prefetch["cast_geometry"] = cast_geometry
|
||||
prefetch["needs_cast"] = needs_cast
|
||||
s._prefetch = prefetch
|
||||
|
||||
return offload_stream
|
||||
|
||||
|
||||
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
|
||||
|
||||
prefetch = getattr(s, "_prefetch", None)
|
||||
|
||||
if prefetch["resident"]:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = prefetch["xfer_dest"]
|
||||
if prefetch["needs_cast"]:
|
||||
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
|
||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
||||
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
|
||||
if post_cast is not None:
|
||||
post_cast.copy_(pre_cast)
|
||||
xfer_dest = cast_dest
|
||||
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
if signature is not None:
|
||||
if prefetch["signature"] is not None:
|
||||
s._v_weight = weight
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
s._v_signature = prefetch["signature"]
|
||||
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
fns = getattr(s, param_key + "_function", [])
|
||||
|
||||
if x is None:
|
||||
return None
|
||||
|
||||
orig = x
|
||||
|
||||
def to_dequant(tensor, dtype):
|
||||
@ -197,14 +248,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
update_weight = signature is not None
|
||||
update_weight = prefetch["signature"] is not None
|
||||
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
|
||||
if bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
|
||||
|
||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||
if s.bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||
if prefetch["signature"] is not None:
|
||||
prefetch["resident"] = True
|
||||
|
||||
#FIXME: weird offload return protocol
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
return weight, bias
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||
@ -222,10 +274,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
def format_return(result, offloadable):
|
||||
weight, bias, offload_stream = result
|
||||
return (weight, bias, offload_stream) if offloadable else (weight, bias)
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
|
||||
prefetched = hasattr(s, "_prefetch")
|
||||
offload_stream = None
|
||||
offload_device = None
|
||||
if not prefetched:
|
||||
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
|
||||
|
||||
if not prefetched:
|
||||
if getattr(s, "_prefetch")["signature"] is not None:
|
||||
offload_device = device
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_fn.clear_prepared()
|
||||
delattr(s, "_prefetch")
|
||||
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
|
||||
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
@ -272,11 +360,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
if offloadable:
|
||||
return weight, bias, (offload_stream, weight_a, bias_a)
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
|
||||
|
||||
|
||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
@ -306,6 +390,12 @@ class CastWeightBiasOp:
|
||||
bias_function = []
|
||||
|
||||
class disable_weight_init:
|
||||
@staticmethod
|
||||
def _zero_init_parameter(module, name):
|
||||
param = getattr(module, name)
|
||||
device = None if getattr(param, "is_meta", False) else param.device
|
||||
setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False))
|
||||
|
||||
@staticmethod
|
||||
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
||||
missing_keys, unexpected_keys, weight_shape,
|
||||
@ -472,6 +562,25 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
|
||||
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
|
||||
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
@ -659,6 +768,9 @@ class manual_cast(disable_weight_init):
|
||||
class Conv3d(disable_weight_init.Conv3d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class BatchNorm2d(disable_weight_init.BatchNorm2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class GroupNorm(disable_weight_init.GroupNorm):
|
||||
comfy_cast_weights = True
|
||||
|
||||
@ -1205,6 +1317,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
class Embedding(manual_cast.Embedding):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
weight_key = f"{prefix}weight"
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
# Only fp8 makes sense for embeddings (per-row dequant via index select).
|
||||
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
|
||||
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
|
||||
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
|
||||
self.quant_format = quant_format
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
|
||||
weight = state_dict.pop(weight_key)
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(scale_key, None)
|
||||
if scale is not None:
|
||||
scale = scale.float()
|
||||
manually_loaded_keys.append(scale_key)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.num_embeddings, self.embedding_dim),
|
||||
)
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
|
||||
requires_grad=False)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
else:
|
||||
if layer_conf is not None:
|
||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight') or self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
|
||||
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
||||
weight = self.weight
|
||||
|
||||
# Optimized path: lookup in fp8, dequantize only the selected rows.
|
||||
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
|
||||
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
|
||||
if isinstance(qdata, QuantizedTensor):
|
||||
scale = qdata._params.scale
|
||||
qdata = qdata._qdata
|
||||
else:
|
||||
scale = None
|
||||
|
||||
x = torch.nn.functional.embedding(
|
||||
input, qdata, self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
uncast_bias_weight(self, qdata, None, offload_stream)
|
||||
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
|
||||
x = x.to(dtype=target_dtype)
|
||||
if scale is not None and scale != 1.0:
|
||||
x = x * scale.to(dtype=target_dtype)
|
||||
return x
|
||||
|
||||
# Fallback for non-quantized or weight_function (LoRA) case
|
||||
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
|
||||
|
||||
return MixedPrecisionOps
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
|
||||
@ -2,7 +2,6 @@ import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
import psutil
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
@ -12,11 +11,6 @@ def get_pin(module):
|
||||
def pin_memory(module):
|
||||
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
|
||||
return
|
||||
#FIXME: This is a RAM cache trigger event
|
||||
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
|
||||
#we split the difference and assume half the RAM cache headroom is for us
|
||||
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
|
||||
comfy.memory_management.extra_ram_release(ram_headroom)
|
||||
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
from comfy_kitchen.tensor import (
|
||||
@ -27,7 +29,15 @@ try:
|
||||
"other kitchen CUDA ops (svdquant W4A4, fp8, mxfp8, rope) remain active.",
|
||||
".".join(map(str, cuda_version)))
|
||||
|
||||
ck.registry.disable("triton")
|
||||
if args.enable_triton_backend:
|
||||
try:
|
||||
import triton
|
||||
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
|
||||
except ImportError as e:
|
||||
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
|
||||
ck.registry.disable("triton")
|
||||
else:
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
except ImportError as e:
|
||||
|
||||
@ -3,6 +3,7 @@ import comfy.model_management
|
||||
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
|
||||
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if weight is None:
|
||||
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
|
||||
|
||||
@ -89,7 +89,8 @@ def get_additional_models(conds, dtype):
|
||||
gligen += get_models_from_cond(conds[k], "gligen")
|
||||
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||
|
||||
control_nets = set(cnets)
|
||||
# Order-preserving dedup. A plain set() would randomize iteration order across runs
|
||||
control_nets = list(dict.fromkeys(cnets))
|
||||
|
||||
inference_memory = 0
|
||||
control_models = []
|
||||
|
||||
39
comfy/sd.py
39
comfy/sd.py
@ -18,6 +18,7 @@ import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
import comfy.ldm.mmaudio.vae.autoencoder
|
||||
import comfy.pixel_space_convert
|
||||
@ -64,6 +65,8 @@ import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -478,7 +481,10 @@ class VAE:
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||
elif "taesd_decoder.1.weight" in sd:
|
||||
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
||||
if isinstance(metadata, dict) and "tae_latent_channels" in metadata:
|
||||
self.latent_channels = metadata["tae_latent_channels"]
|
||||
else:
|
||||
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
||||
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
|
||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||
self.first_stage_model = StageA()
|
||||
@ -652,6 +658,17 @@ class VAE:
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2
|
||||
self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels)
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
@ -1208,6 +1225,7 @@ class CLIPType(Enum):
|
||||
NEWBIE = 24
|
||||
FLUX2 = 25
|
||||
LONGCAT_IMAGE = 26
|
||||
COGVIDEOX = 27
|
||||
|
||||
|
||||
|
||||
@ -1256,6 +1274,9 @@ class TEModel(Enum):
|
||||
QWEN35_9B = 26
|
||||
QWEN35_27B = 27
|
||||
MINISTRAL_3_3B = 28
|
||||
GEMMA_4_E4B = 29
|
||||
GEMMA_4_E2B = 30
|
||||
GEMMA_4_31B = 31
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1281,6 +1302,12 @@ def detect_te_model(sd):
|
||||
return TEModel.BYT5_SMALL_GLYPH
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
if 'model.layers.59.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_4_31B
|
||||
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
|
||||
return TEModel.GEMMA_4_E4B
|
||||
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
|
||||
return TEModel.GEMMA_4_E2B
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_12B
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
@ -1403,6 +1430,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif clip_type == CLIPType.COGVIDEOX:
|
||||
clip_target.clip = comfy.text_encoders.cogvideo.cogvideo_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.cogvideo.CogVideoXTokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||
@ -1420,6 +1450,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
|
||||
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
|
||||
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
|
||||
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
|
||||
clip_target.tokenizer = variant.tokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
|
||||
@ -27,6 +27,7 @@ import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.cogvideo
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -1166,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
|
||||
|
||||
class WAN21_CausalAR_T2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "t2v",
|
||||
"causal_ar": True,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 5.0,
|
||||
}
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.unet_config.pop("causal_ar", None)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.WAN21_CausalAR(self, device=device)
|
||||
|
||||
|
||||
class WAN21_I2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1832,6 +1852,156 @@ class SAM31(SAM3):
|
||||
unet_config = {"image_model": "SAM31"}
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31]
|
||||
class CogVideoX_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "cogvideox",
|
||||
}
|
||||
|
||||
models += [SVD_img2vid]
|
||||
sampling_settings = {
|
||||
"linear_start": 0.00085,
|
||||
"linear_end": 0.012,
|
||||
"beta_schedule": "linear",
|
||||
"zsnr": True,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.CogVideoX
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
# 2b-class (dim=1920, heads=30) uses scale_factor=1.15258426.
|
||||
# 5b-class (dim=3072, heads=48) — incl. CogVideoX-5b, 1.5-5B, and
|
||||
# Fun-V1.5 inpainting — uses scale_factor=0.7 per vae/config.json.
|
||||
if unet_config.get("num_attention_heads", 0) >= 48:
|
||||
self.latent_format = latent_formats.CogVideoX1_5
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
# CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
self.unet_config.setdefault("sample_height", 96)
|
||||
self.unet_config.setdefault("sample_width", 170)
|
||||
self.unet_config.setdefault("sample_frames", 81)
|
||||
out = model_base.CogVideoX(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.cogvideo.CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel)
|
||||
|
||||
class CogVideoX_I2V(CogVideoX_T2V):
|
||||
unet_config = {
|
||||
"image_model": "cogvideox",
|
||||
"in_channels": 32,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
self.unet_config.setdefault("sample_height", 96)
|
||||
self.unet_config.setdefault("sample_width", 170)
|
||||
self.unet_config.setdefault("sample_frames", 81)
|
||||
out = model_base.CogVideoX(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class CogVideoX_Inpaint(CogVideoX_T2V):
|
||||
unet_config = {
|
||||
"image_model": "cogvideox",
|
||||
"in_channels": 48,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
self.unet_config.setdefault("sample_height", 96)
|
||||
self.unet_config.setdefault("sample_width", 170)
|
||||
self.unet_config.setdefault("sample_frames", 81)
|
||||
out = model_base.CogVideoX(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
|
||||
models = [
|
||||
LotusD,
|
||||
Stable_Zero123,
|
||||
SD15_instructpix2pix,
|
||||
SD15,
|
||||
SD20,
|
||||
SD21UnclipL,
|
||||
SD21UnclipH,
|
||||
SDXL_instructpix2pix,
|
||||
SDXLRefiner,
|
||||
SDXL,
|
||||
SSD1B,
|
||||
KOALA_700M,
|
||||
KOALA_1B,
|
||||
Segmind_Vega,
|
||||
SD_X4Upscaler,
|
||||
Stable_Cascade_C,
|
||||
Stable_Cascade_B,
|
||||
SV3D_u,
|
||||
SV3D_p,
|
||||
SD3,
|
||||
StableAudio,
|
||||
AuraFlow,
|
||||
PixArtAlpha,
|
||||
PixArtSigma,
|
||||
HunyuanDiT,
|
||||
HunyuanDiT1,
|
||||
FluxInpaint,
|
||||
Flux,
|
||||
LongCatImage,
|
||||
FluxSchnell,
|
||||
GenmoMochi,
|
||||
LTXV,
|
||||
LTXAV,
|
||||
HunyuanVideo15_SR_Distilled,
|
||||
HunyuanVideo15,
|
||||
HunyuanImage21Refiner,
|
||||
HunyuanImage21,
|
||||
HunyuanVideoSkyreelsI2V,
|
||||
HunyuanVideoI2V,
|
||||
HunyuanVideo,
|
||||
CosmosT2V,
|
||||
CosmosI2V,
|
||||
CosmosT2IPredict2,
|
||||
CosmosI2VPredict2,
|
||||
ZImagePixelSpace,
|
||||
ZImage,
|
||||
Lumina2,
|
||||
WAN22_T2V,
|
||||
WAN21_CausalAR_T2V,
|
||||
WAN21_T2V,
|
||||
WAN21_I2V,
|
||||
WAN21_FunControl2V,
|
||||
WAN21_Vace,
|
||||
WAN21_Camera,
|
||||
WAN22_Camera,
|
||||
WAN22_S2V,
|
||||
WAN21_HuMo,
|
||||
WAN22_Animate,
|
||||
WAN21_FlowRVS,
|
||||
WAN21_SCAIL,
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
Hunyuan3Dv2_1,
|
||||
HiDream,
|
||||
Chroma,
|
||||
ChromaRadiance,
|
||||
ACEStep,
|
||||
ACEStep15,
|
||||
Omnigen2,
|
||||
QwenImage,
|
||||
Flux2,
|
||||
Kandinsky5Image,
|
||||
Kandinsky5,
|
||||
Anima,
|
||||
RT_DETR_v4,
|
||||
ErnieImage,
|
||||
SAM3,
|
||||
SAM31,
|
||||
CogVideoX_Inpaint,
|
||||
CogVideoX_I2V,
|
||||
CogVideoX_T2V,
|
||||
SVD_img2vid,
|
||||
]
|
||||
|
||||
@ -7,6 +7,7 @@ from tqdm.auto import tqdm
|
||||
from collections import namedtuple, deque
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
operations=comfy.ops.disable_weight_init
|
||||
|
||||
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
||||
@ -47,11 +48,14 @@ class TGrow(nn.Module):
|
||||
x = self.conv(x)
|
||||
return x.reshape(-1, C, H, W)
|
||||
|
||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, output_device=None,
|
||||
patch_size=1, decode=False):
|
||||
|
||||
B, T, C, H, W = x.shape
|
||||
if parallel:
|
||||
x = x.reshape(B*T, C, H, W)
|
||||
if not decode and patch_size > 1:
|
||||
x = F.pixel_unshuffle(x, patch_size)
|
||||
# parallel over input timesteps, iterate over blocks
|
||||
for b in tqdm(model, disable=not show_progress_bar):
|
||||
if isinstance(b, MemBlock):
|
||||
@ -62,20 +66,27 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
x = b(x, mem)
|
||||
else:
|
||||
x = b(x)
|
||||
BT, C, H, W = x.shape
|
||||
T = BT // B
|
||||
x = x.view(B, T, C, H, W)
|
||||
if decode and patch_size > 1:
|
||||
x = F.pixel_shuffle(x, patch_size)
|
||||
x = x.view(B, x.shape[0] // B, *x.shape[1:])
|
||||
x = x.to(output_device)
|
||||
else:
|
||||
out = []
|
||||
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
|
||||
# Chunk along the time dim directly (chunks are [B,1,C,H,W] views, squeeze to [B,C,H,W] views).
|
||||
# Avoids forcing a contiguous copy when x is non-contiguous (e.g. after movedim in encode/decode).
|
||||
work_queue = deque([TWorkItem(xt.squeeze(1), 0) for xt in x.chunk(T, dim=1)])
|
||||
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
||||
mem = [None] * len(model)
|
||||
while work_queue:
|
||||
xt, i = work_queue.popleft()
|
||||
if i == 0:
|
||||
progress_bar.update(1)
|
||||
if not decode and patch_size > 1:
|
||||
xt = F.pixel_unshuffle(xt, patch_size)
|
||||
if i == len(model):
|
||||
out.append(xt)
|
||||
if decode and patch_size > 1:
|
||||
xt = F.pixel_shuffle(xt, patch_size)
|
||||
out.append(xt.to(output_device))
|
||||
del xt
|
||||
else:
|
||||
b = model[i]
|
||||
@ -165,24 +176,20 @@ class TAEHV(nn.Module):
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if self.patch_size > 1:
|
||||
B, T, C, H, W = x.shape
|
||||
x = x.reshape(B * T, C, H, W)
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
|
||||
if x.shape[1] % self.t_downscale != 0:
|
||||
# pad at end to multiple of t_downscale
|
||||
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
|
||||
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
|
||||
x = torch.cat([x, padding], 1)
|
||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
|
||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar,
|
||||
patch_size=self.patch_size).movedim(2, 1)
|
||||
return self.process_out(x)
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
|
||||
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_shuffle(x, self.patch_size)
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar,
|
||||
output_device=comfy.model_management.intermediate_device(),
|
||||
patch_size=self.patch_size, decode=True)
|
||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||
|
||||
@ -17,32 +17,79 @@ class Clamp(nn.Module):
|
||||
return torch.tanh(x / 3) * 3
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, n_in, n_out):
|
||||
def __init__(self, n_in: int, n_out: int, use_midblock_gn: bool = False):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
||||
self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.fuse = nn.ReLU()
|
||||
def forward(self, x):
|
||||
if not use_midblock_gn:
|
||||
self.pool = None
|
||||
return
|
||||
n_gn = n_in * 4
|
||||
self.pool = nn.Sequential(
|
||||
comfy.ops.disable_weight_init.Conv2d(n_in, n_gn, 1, bias=False),
|
||||
comfy.ops.disable_weight_init.GroupNorm(4, n_gn),
|
||||
nn.ReLU(inplace=True),
|
||||
comfy.ops.disable_weight_init.Conv2d(n_gn, n_in, 1, bias=False),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.pool is not None:
|
||||
x = x + self.pool(x)
|
||||
return self.fuse(self.conv(x) + self.skip(x))
|
||||
|
||||
def Encoder(latent_channels=4):
|
||||
return nn.Sequential(
|
||||
conv(3, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, latent_channels),
|
||||
)
|
||||
class Encoder(nn.Sequential):
|
||||
def __init__(self, latent_channels: int = 4, use_gn: bool = False):
|
||||
super().__init__(
|
||||
conv(3, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn),
|
||||
conv(64, latent_channels),
|
||||
)
|
||||
|
||||
class Decoder(nn.Sequential):
|
||||
def __init__(self, latent_channels: int = 4, use_gn: bool = False):
|
||||
super().__init__(
|
||||
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
||||
Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), conv(64, 3),
|
||||
)
|
||||
|
||||
class DecoderFlux2(Decoder):
|
||||
def __init__(self, latent_channels: int = 128, use_gn: bool = True):
|
||||
if latent_channels != 128 or not use_gn:
|
||||
raise ValueError("Unexpected parameters for Flux2 TAE module")
|
||||
super().__init__(latent_channels=32, use_gn=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, C, H, W = x.shape
|
||||
x = (
|
||||
x
|
||||
.reshape(B, 32, 2, 2, H, W)
|
||||
.permute(0, 1, 4, 2, 5, 3)
|
||||
.reshape(B, 32, H * 2, W * 2)
|
||||
)
|
||||
return super().forward(x)
|
||||
|
||||
class EncoderFlux2(Encoder):
|
||||
def __init__(self, latent_channels: int = 128, use_gn: bool = True):
|
||||
if latent_channels != 128 or not use_gn:
|
||||
raise ValueError("Unexpected parameters for Flux2 TAE module")
|
||||
super().__init__(latent_channels=32, use_gn=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
result = super().forward(x)
|
||||
B, C, H, W = result.shape
|
||||
return (
|
||||
result
|
||||
.reshape(B, C, H // 2, 2, W // 2, 2)
|
||||
.permute(0, 1, 3, 5, 2, 4)
|
||||
.reshape(B, 128, H // 2, W // 2)
|
||||
)
|
||||
|
||||
def Decoder(latent_channels=4):
|
||||
return nn.Sequential(
|
||||
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), conv(64, 3),
|
||||
)
|
||||
|
||||
class TAESD(nn.Module):
|
||||
latent_magnitude = 3
|
||||
@ -51,8 +98,15 @@ class TAESD(nn.Module):
|
||||
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.taesd_encoder = Encoder(latent_channels=latent_channels)
|
||||
self.taesd_decoder = Decoder(latent_channels=latent_channels)
|
||||
if latent_channels == 128:
|
||||
encoder_class = EncoderFlux2
|
||||
decoder_class = DecoderFlux2
|
||||
else:
|
||||
encoder_class = Encoder
|
||||
decoder_class = Decoder
|
||||
self.taesd_encoder = encoder_class(latent_channels=latent_channels)
|
||||
self.taesd_decoder = decoder_class(latent_channels=latent_channels)
|
||||
|
||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
|
||||
if encoder_path is not None:
|
||||
@ -61,19 +115,19 @@ class TAESD(nn.Module):
|
||||
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
def scale_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
"""raw latents -> [0, 1]"""
|
||||
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
|
||||
|
||||
@staticmethod
|
||||
def unscale_latents(x):
|
||||
def unscale_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
def decode(self, x):
|
||||
def decode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
|
||||
x_sample = x_sample.sub(0.5).mul(2)
|
||||
return x_sample
|
||||
|
||||
def encode(self, x):
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
||||
|
||||
48
comfy/text_encoders/cogvideo.py
Normal file
48
comfy/text_encoders/cogvideo.py
Normal file
@ -0,0 +1,48 @@
|
||||
import comfy.text_encoders.sd3_clip
|
||||
from comfy import sd1_clip
|
||||
|
||||
|
||||
class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
|
||||
"""Inner T5 tokenizer for CogVideoX.
|
||||
|
||||
CogVideoX was trained with T5 embeddings padded to 226 tokens (not 77 like SD3).
|
||||
Used both directly by supported_models.CogVideoX_T2V.clip_target (paired with
|
||||
the raw T5XXLModel) and by the CogVideoXTokenizer outer wrapper below.
|
||||
"""
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)
|
||||
|
||||
|
||||
class CogVideoXTokenizer(sd1_clip.SD1Tokenizer):
|
||||
"""Outer tokenizer wrapper for CLIPLoader (type="cogvideox")."""
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
clip_name="t5xxl", tokenizer=CogVideoXT5Tokenizer)
|
||||
|
||||
|
||||
class CogVideoXT5XXL(sd1_clip.SD1ClipModel):
|
||||
"""Outer T5XXL model wrapper for CLIPLoader (type="cogvideox").
|
||||
|
||||
Wraps the raw T5XXL model in the SD1ClipModel interface so that CLIP.__init__
|
||||
(which reads self.dtypes) works correctly. The inner model is the standard
|
||||
sd3_clip.T5XXLModel (no attention_mask change needed for CogVideoX).
|
||||
"""
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl",
|
||||
clip_model=comfy.text_encoders.sd3_clip.T5XXLModel,
|
||||
model_options=model_options)
|
||||
|
||||
|
||||
def cogvideo_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
"""Factory that returns a CogVideoXT5XXL class configured with the detected
|
||||
T5 dtype and optional quantization metadata, for use in load_text_encoder_state_dicts.
|
||||
"""
|
||||
class CogVideoXTEModel_(CogVideoXT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return CogVideoXTEModel_
|
||||
1298
comfy/text_encoders/gemma4.py
Normal file
1298
comfy/text_encoders/gemma4.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -521,7 +521,7 @@ class Attention(nn.Module):
|
||||
else:
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window:
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
|
||||
xk = xk[:, :, -sliding_window:]
|
||||
xv = xv[:, :, -sliding_window:]
|
||||
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
|
||||
@ -533,12 +533,12 @@ class Attention(nn.Module):
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
|
||||
super().__init__()
|
||||
ops = ops or nn
|
||||
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
intermediate_size = intermediate_size or config.intermediate_size
|
||||
self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
if config.mlp_activation == "silu":
|
||||
self.activation = torch.nn.functional.silu
|
||||
elif config.mlp_activation == "gelu_pytorch_tanh":
|
||||
@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
|
||||
|
||||
return x, present_key_value
|
||||
|
||||
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
|
||||
class ScaledEmbedding(ops.Embedding):
|
||||
def forward(self, input_ids, out_dtype=None):
|
||||
return super().forward(input_ids, out_dtype=out_dtype) * scale
|
||||
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
|
||||
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = ops.Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
|
||||
transformer = TransformerBlockGemma2
|
||||
self.normalize_in = True
|
||||
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
|
||||
else:
|
||||
transformer = TransformerBlock
|
||||
self.normalize_in = False
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
@ -690,15 +691,12 @@ class Llama2_(nn.Module):
|
||||
self.config.rope_dims,
|
||||
device=device)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
seq_len = x.shape[1]
|
||||
past_len = 0
|
||||
if past_key_values is not None and len(past_key_values) > 0:
|
||||
@ -850,7 +848,7 @@ class BaseGenerate:
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
@ -875,14 +873,16 @@ class BaseGenerate:
|
||||
pbar = comfy.utils.ProgressBar(max_length)
|
||||
|
||||
# Generation loop
|
||||
current_input_ids = initial_input_ids
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
token_id = next_token[0].item()
|
||||
generated_token_ids.append(token_id)
|
||||
|
||||
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||
current_input_ids = next_token if initial_input_ids is not None else None
|
||||
pbar.update(1)
|
||||
|
||||
if token_id in stop_tokens:
|
||||
|
||||
@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
|
||||
|
||||
class DualLinearProjection(torch.nn.Module):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user