Merge branch 'master' of github.com:comfyanonymous/ComfyUI into merge/0.3.76-snapshot

This commit is contained in:
doctorpangloss 2025-12-09 08:58:52 -08:00
commit a00c902067
177 changed files with 20434 additions and 11564 deletions

View File

@ -0,0 +1,3 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@ -8,13 +8,15 @@ body:
Before submitting a **Bug Report**, please ensure the following:
- **1:** You are running the latest version of ComfyUI.
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
`--disable-all-custom-nodes` command line argument.
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
## Very Important
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
- type: checkboxes
id: custom-nodes-test
attributes:

View File

@ -0,0 +1,21 @@
<!-- API_NODE_PR_CHECKLIST: do not remove -->
## API Node PR Checklist
### Scope
- [ ] **Is API Node Change**
### Pricing & Billing
- [ ] **Need pricing update**
- [ ] **No pricing update**
If **Need pricing update**:
- [ ] Metronome rate cards updated
- [ ] Autobilling tests updated and passing
### QA
- [ ] **QA done**
- [ ] **QA not required**
### Comms
- [ ] Informed **Kosinkadink**

58
.github/workflows/api-node-template.yml vendored Normal file
View File

@ -0,0 +1,58 @@
name: Append API Node PR template
on:
pull_request_target:
types: [opened, reopened, synchronize, ready_for_review]
paths:
- 'comfy_api_nodes/**' # only run if these files changed
permissions:
contents: read
pull-requests: write
jobs:
inject:
runs-on: ubuntu-latest
steps:
- name: Ensure template exists and append to PR body
uses: actions/github-script@v7
with:
script: |
const { owner, repo } = context.repo;
const number = context.payload.pull_request.number;
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
let templateText;
try {
const res = await github.rest.repos.getContent({
owner,
repo,
path: templatePath,
ref: pr.base.ref
});
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
templateText = buf.toString('utf8');
} catch (e) {
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
return;
}
// Enforce the presence of the marker inside the template (for idempotence)
if (!templateText.includes(marker)) {
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
return;
}
// If the PR already contains the marker, do not append again.
const body = pr.body || '';
if (body.includes(marker)) {
core.info('Template already present in PR body; nothing to inject.');
return;
}
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
core.notice('API Node template appended to PR description.');

View File

@ -1,6 +1,2 @@
# Admins
# From upstream
* @comfyanonymous
* @kosinkadink
# For the fork
* @doctorpangloss
* @comfyanonymous @kosinkadink @guill @doctorpangloss

168
QUANTIZATION.md Normal file
View File

@ -0,0 +1,168 @@
# The Comfy guide to Quantization
## How does quantization work?
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
```
absmax = max(abs(tensor))
scale = amax / max_dynamic_range_low_precision
# Quantization
tensor_q = (tensor / scale).to(low_precision_dtype)
# De-Quantization
tensor_dq = tensor_q.to(fp16) * scale
tensor_dq ~ tensor
```
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
## Quantization in Comfy
```
QuantizedTensor (torch.Tensor subclass)
↓ __torch_dispatch__
Two-Level Registry (generic + layout handlers)
MixedPrecisionOps + Metadata Detection
```
### Representation
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
A `Layout` class defines how a specific quantization format behaves:
- Required parameters
- Quantize method
- De-Quantize method
```python
from comfy.quant_ops import QuantizedLayout
class MyLayout(QuantizedLayout):
@classmethod
def quantize(cls, tensor, **kwargs):
# Convert to quantized format
qdata = ...
params = {'scale': ..., 'orig_dtype': tensor.dtype}
return qdata, params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
return qdata.to(orig_dtype) * scale
```
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
```python
from comfy.quant_ops import register_layout_op
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
def my_linear(func, args, kwargs):
# Extract tensors, call optimized kernel
...
```
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
### Mixed Precision
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
**Architecture:**
```python
class MixedPrecisionOps(disable_weight_init):
_layer_quant_config = {} # Maps layer names to quantization configs
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
```
**Key mechanism:**
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
- If the layer name **is** in `_layer_quant_config`:
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
- Load associated quantization parameters (scales, block_size, etc.)
**Why it's needed:**
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
## Checkpoint Format
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
The quantized checkpoint will contain the same layers as the original checkpoint but:
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
### Scaling Parameters details
We define 4 possible scaling parameters that should cover most recipes in the near-future:
- **weight_scale**: quantization scalers for the weights
- **weight_scale_2**: global scalers in the context of double scaling
- **pre_quant_scale**: scalers used for smoothing salient weights
- **input_scale**: quantization scalers for the activations
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
### Quantization Metadata
The metadata stored alongside the checkpoint contains:
- **format_version**: String to define a version of the standard
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
Example:
```json
{
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
}
}
}
```
## Creating Quantized Checkpoints
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
### Weight Quantization
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
### Calibration (for Activation Quantization)
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
1. **Collect statistics**: Run inference on N representative samples
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.

View File

@ -359,6 +359,8 @@ There are two kinds of custom nodes: vanilla custom nodes, which generally expec
ComfyUI-Manager is a popular extension to help you install and manage other custom nodes. To install it, you will need `git` on your system.
#### Manual Install
The installation process for ComfyUI-Manager requires two steps: installing its Python dependencies, and then cloning its code into the `custom_nodes` directory.
1. **Install dependencies.**
@ -381,6 +383,34 @@ The installation process for ComfyUI-Manager requires two steps: installing its
3. **Restart ComfyUI.**
After the cloning is complete, restart ComfyUI. You should now see a "Manager" button in the menu.
### PyPi Install
[ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
### Setup
1. Install the manager dependencies:
```bash
pip install -r manager_requirements.txt
```
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
```bash
python main.py --enable-manager
```
### Command Line Options
| Flag | Description |
|------|-------------|
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
### Vanilla Custom Nodes
Clone the repository containing the custom nodes into `custom_nodes/` in your working directory and install its requirements, or use the manager.

View File

@ -1,6 +1,6 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.66"
__version__ = "0.3.76"
# This deals with workspace issues
from comfy_compatibility.workspace import auto_patch_workspace_and_restart

View File

@ -11,13 +11,16 @@ import importlib.metadata
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict, Optional
from typing import Dict, TypedDict, Optional
from aiohttp import web
from importlib.metadata import version
import requests
from typing_extensions import NotRequired
from ..cli_args import DEFAULT_VERSION_STRING
from ..cmd.folder_paths import add_model_folder_path # pylint: disable=import-error
logger = logging.getLogger(__name__)
REQUEST_TIMEOUT = 10 # seconds
@ -172,7 +175,53 @@ class FrontendManager:
return ""
@classmethod
def templates_path(cls) -> str:
def template_asset_map(cls) -> Optional[Dict[str, str]]:
"""Return a mapping of template asset names to their absolute paths."""
try:
from comfyui_workflow_templates import (
get_asset_path,
iter_templates,
)
except ImportError:
logging.error(
f"""
********** ERROR ***********
comfyui-workflow-templates is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
return None
try:
template_entries = list(iter_templates())
except Exception as exc:
logging.error(f"Failed to enumerate workflow templates: {exc}")
return None
asset_map: Dict[str, str] = {}
try:
for entry in template_entries:
for asset in entry.assets:
asset_map[asset.filename] = get_asset_path(
entry.template_id, asset.filename
)
except Exception as exc:
logging.error(f"Failed to resolve template asset paths: {exc}")
return None
if not asset_map:
logging.error("No workflow template assets found. Did the packages install correctly?")
return None
return asset_map
@classmethod
def legacy_templates_path(cls) -> Optional[str]:
"""Return the legacy templates directory shipped inside the meta package."""
try:
import comfyui_workflow_templates
@ -299,3 +348,18 @@ class FrontendManager:
logger.info("Falling back to the default frontend.")
check_frontend_version()
return cls.default_frontend_path()
@classmethod
def template_asset_handler(cls):
assets = cls.template_asset_map()
if not assets:
return None
async def serve_template(request: web.Request) -> web.StreamResponse:
rel_path = request.match_info.get("path", "")
target = assets.get(rel_path)
if target is None:
raise web.HTTPNotFound()
return web.FileResponse(target)
return serve_template

View File

@ -0,0 +1,112 @@
from __future__ import annotations
from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib
class Source:
custom_node = "custom_node"
class SubgraphEntry(TypedDict):
source: str
"""
Source of subgraph - custom_nodes vs templates.
"""
path: str
"""
Relative path of the subgraph file.
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
"""
name: str
"""
Name of subgraph file.
"""
info: CustomNodeSubgraphEntryInfo
"""
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
"""
data: str
class CustomNodeSubgraphEntryInfo(TypedDict):
node_pack: str
"""Node pack name."""
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f:
entry['data'] = f.read()
return entry
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
if entry is None:
return None
entry = entry.copy()
entry.pop('path', None)
if remove_data:
entry.pop('data', None)
return entry
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
entries = entries.copy()
for key in list(entries.keys()):
entries[key] = await self.sanitize_entry(entries[key], remove_data)
return entries
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
# if not forced to reload and cached, return cache
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
# replace backslashes with forward slashes
file = file.replace('\\', '/')
info: CustomNodeSubgraphEntryInfo = {
"node_pack": "custom_nodes." + file.split('/')[-3]
}
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": info,
}
subgraphs_dict[id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_custom_node_subgraph(self, id: str, loadedModules):
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
entry: SubgraphEntry = subgraphs.get(id, None)
if entry is not None and entry.get('data', None) is None:
await self.load_entry_data(entry)
return entry
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
return web.json_response(await self.sanitize_entry(subgraph))

View File

@ -60,6 +60,9 @@ class UserManager():
user = "default"
if args.multi_user and "comfy-user" in request.headers:
user = request.headers["comfy-user"]
# Block System Users (use same error message to prevent probing)
if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
raise KeyError("Unknown user: " + user)
if user not in self.users:
raise KeyError("Unknown user: " + user)
@ -67,15 +70,16 @@ class UserManager():
return user
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
user_directory = folder_paths.get_user_directory()
if type == "userdata":
root_dir = user_directory
root_dir = folder_paths.get_user_directory()
else:
raise KeyError("Unknown filepath type:" + type)
user = self.get_request_user_id(request)
path = user_root = os.path.abspath(os.path.join(root_dir, user))
user_root = folder_paths.get_public_user_directory(user)
if user_root is None:
return None
path = user_root
# prevent leaving /{type}
if os.path.commonpath((root_dir, user_root)) != root_dir:
@ -102,7 +106,11 @@ class UserManager():
name = name.strip()
if not name:
raise ValueError("username not provided")
if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
raise ValueError("System User prefix not allowed")
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
raise ValueError("System User prefix not allowed")
user_id = user_id + "_" + str(uuid.uuid4())
self.users[user_id] = name
@ -133,7 +141,10 @@ class UserManager():
if username in self.users.values():
return web.json_response({"error": "Duplicate username."}, status=400)
user_id = self.add_user(username)
try:
user_id = self.add_user(username)
except ValueError as e:
return web.json_response({"error": str(e)}, status=400)
return web.json_response(user_id)
@routes.get("/userdata")
@ -425,7 +436,7 @@ class UserManager():
return source
dest = get_user_data_path(request, check_exists=False, param="dest")
if not isinstance(source, str):
if not isinstance(dest, str):
return dest
overwrite = request.query.get("overwrite", 'true') != "false"

View File

@ -415,7 +415,8 @@ class ControlNet(nn.Module):
out_middle = []
if self.num_classes is not None:
assert y.shape[0] == x.shape[0], "There may be a mismatch between the ControlNet and Diffusion models being used"
if y is None:
raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
emb = emb + self.label_emb(y)
h = x

View File

@ -105,6 +105,7 @@ def _create_parser() -> EnhancedConfigArgParser:
cache_group.add_argument("--cache-classic", action="store_true", help="WARNING: Unused. Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true",
help="Use the split cross attention optimization. Ignored when xformers is used.")
@ -120,6 +121,10 @@ def _create_parser() -> EnhancedConfigArgParser:
upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
manager_group = parser.add_mutually_exclusive_group()
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true",
help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
@ -132,7 +137,8 @@ def _create_parser() -> EnhancedConfigArgParser:
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true",
@ -141,6 +147,7 @@ def _create_parser() -> EnhancedConfigArgParser:
help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help=f"Enable some untested and potentially quality deteriorating optimizations. Pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {' '.join([f.value for f in PerformanceFeature])}", default=set())
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
@ -155,7 +162,7 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
parser.add_argument("--enable-eval", action="store_true", help="Enable nodes that can evaluate Python code in workflows.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

View File

@ -73,6 +73,7 @@ class Configuration(dict):
temp_directory (Optional[str]): Temporary directory for processing.
input_directory (Optional[str]): Directory for input files. When this is a relative path, it will be looked up relative to the cwd (current working directory) and all of the base_paths.
auto_launch (bool): Auto-launch UI in the default browser. Defaults to False.
disable_auto_launch (bool): Disable auto launching the browser.
cuda_device (Optional[int]): CUDA device ID. None means default device.
cuda_malloc (bool): Enable cudaMallocAsync. Defaults to True in applicable setups.
disable_cuda_malloc (bool): Disable cudaMallocAsync.
@ -100,6 +101,7 @@ class Configuration(dict):
disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs.
preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto".
cache_lru (int): Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.
cache_ram (float): Use RAM pressure caching with the specified headroom threshold.
use_split_cross_attention (bool): Use split cross-attention optimization.
use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
@ -147,14 +149,19 @@ class Configuration(dict):
user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path.
log_stdout (bool): Send normal process output to stdout instead of stderr (default)
panic_when (list[str]): List of fully qualified exception class names to panic (sys.exit(1)) when a workflow raises it.
enable_manager (bool): Enable the ComfyUI-Manager feature.
disable_manager_ui (bool): Disables only the ComfyUI-Manager UI.
enable_manager_legacy_ui (bool): Enables the legacy UI of ComfyUI-Manager.
enable_compress_response_body (bool): Enable compressing response body.
workflows (list[str]): Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in.
disable_pinned_memory (bool): Disable pinned memory use.
fp8_e8m0fnu_unet (bool): Store unet weights in fp8_e8m0fnu.
bf16_text_enc (bool): Store text encoder weights in bf16.
supports_fp8_compute (bool): ComfyUI will act like if the device supports fp8 compute.
cache_classic (bool): WARNING: Unused. Use the old style (aggressive) caching.
cache_none (bool): Reduced RAM/VRAM usage at the expense of executing every node for each run.
async_offload (bool): Use async weight offloading.
async_offload (Optional[int]): Use async weight offloading. An optional argument controls the amount of offload streams.
disable_async_offload (bool): Disable async weight offloading.
force_non_blocking (bool): Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.
default_hashing_function (str): Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.
mmap_torch_files (bool): Use mmap when loading ckpt/pt files.
@ -189,6 +196,7 @@ class Configuration(dict):
self.temp_directory: Optional[str] = None
self.input_directory: Optional[str] = None
self.auto_launch: bool = False
self.disable_auto_launch: bool = False
self.cuda_device: Optional[int] = None
self.cuda_malloc: bool = True
self.disable_cuda_malloc: bool = True
@ -272,13 +280,19 @@ class Configuration(dict):
self.user_directory: Optional[str] = None
self.panic_when: list[str] = []
self.workflows: list[str] = []
self.enable_manager: bool = False
self.disable_manager_ui: bool = False
self.enable_manager_legacy_ui: bool = False
self.disable_pinned_memory: bool = False
self.fp8_e8m0fnu_unet: bool = False
self.bf16_text_enc: bool = False
self.supports_fp8_compute: bool = False
self.cache_classic: bool = False
self.cache_none: bool = False
self.async_offload: bool = False
self.cache_ram: float = 0.0
self.async_offload: Optional[int] = None
self.disable_async_offload: bool = False
self.force_non_blocking: bool = False
self.default_hashing_function: str = 'sha256'
self.mmap_torch_files: bool = False
@ -289,7 +303,7 @@ class Configuration(dict):
self.comfy_api_base: str = "https://api.comfy.org"
self.database_url: str = db_config()
self.default_device: Optional[int] = None
self.block_runtime_package_installation = None
self.block_runtime_package_installation: bool = False
self.enable_eval: Optional[bool] = False
self.enable_video_to_image_fallback: bool = False

View File

@ -71,18 +71,23 @@ def cuda_malloc_supported():
return True
# todo: is this really how we want to get the torch version?
version = ""
try:
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
ver_file = os.path.join(folder, "version.py")
if os.path.isfile(ver_file):
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
except:
pass
if not args.cuda_malloc:
try:
version = ""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
ver_file = os.path.join(folder, "version.py")
if os.path.isfile(ver_file):
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
version = module.__version__
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
args.cuda_malloc = cuda_malloc_supported()
@ -97,3 +102,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc:
env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
def get_torch_version_noimport():
return str(version)

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from typing_extensions import NotRequired, TypedDict, NamedTuple
from .main_pre import tracer
import asyncio
@ -21,14 +23,24 @@ from typing import List, Optional, Tuple, Literal
import torch
from opentelemetry.trace import get_current_span, StatusCode, Status
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, \
make_locked_method_func
from comfy_api.latest import io
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io
from comfy_compatibility.vanilla import vanilla_environment_node_execution_hooks
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \
DependencyAwareCache, \
BasicCache
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
from comfy_execution.caching import (
BasicCache,
CacheKeySetID,
CacheKeySetInputSignature,
NullCache,
HierarchicalCache,
LRUCache,
RAMPressureCache,
)
from comfy_execution.graph import (
DynamicPrompt,
ExecutionBlocker,
ExecutionList,
get_input_info,
)
from comfy_execution.graph_types import FrozenTopologicalSort
from comfy_execution.graph_utils import is_link, GraphBuilder
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, \
@ -94,7 +106,7 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
is_changed = await resolve_map_node_over_list_results(is_changed)
@ -106,48 +118,58 @@ class IsChangedCache:
return self.is_changed[node_id]
class CacheEntry(NamedTuple):
ui: dict
outputs: list
class CacheType(Enum):
CLASSIC = 0
LRU = 1
DEPENDENCY_AWARE = 2
NONE = 2
RAM_PRESSURE = 3
class CacheSet:
def __init__(self, cache_type=None, cache_size=None):
if cache_type == CacheType.DEPENDENCY_AWARE:
self.init_dependency_aware_cache()
def __init__(self, cache_type=None, cache_args: Optional[CacheArgs] = None):
if cache_args is None:
cache_args = {}
if cache_type == CacheType.NONE:
self.init_null_cache()
logger.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU:
if cache_size is None:
cache_size = 0
cache_size = cache_args.get("lru", 0)
self.init_lru_cache(cache_size)
logger.info("Using LRU cache")
else:
self.init_classic_cache()
self.all = [self.outputs, self.ui, self.objects]
self.all = [self.outputs, self.objects]
# Performs like the old cache -- dump data ASAP
def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.objects = HierarchicalCache(CacheKeySetID)
# only hold cached items while the decendents have not executed
def init_dependency_aware_cache(self):
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
self.objects = DependencyAwareCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self):
self.outputs = NullCache()
self.objects = NullCache()
def recursive_debug_dump(self):
result = {
"outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
}
return result
@ -155,12 +177,14 @@ class CacheSet:
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None):
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data=None):
if extra_data is None:
extra_data = {}
is_v3 = issubclass(class_def, _ComfyNodeInternal)
v3_data: io.V3Data = {}
schema = None
if is_v3:
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
else:
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@ -177,17 +201,17 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
input_unique_id = input_data[0]
output_index = input_data[1]
if outputs is None:
if execution_list is None:
mark_missing()
continue # This might be a lazily-evaluated input
cached_output = outputs.get(input_unique_id)
if cached_output is None:
cached = execution_list.get_cache(input_unique_id, unique_id)
if cached is None or cached.outputs is None:
mark_missing()
continue
if output_index >= len(cached_output):
if output_index >= len(cached.outputs):
mark_missing()
continue
obj = cached_output[output_index]
obj = cached.outputs[output_index]
input_data_all[x] = obj
elif input_category is not None:
input_data_all[x] = [input_data]
@ -223,7 +247,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys, hidden_inputs_v3
v3_data["hidden_inputs"] = hidden_inputs_v3
return input_data_all, missing_keys, v3_data
def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
@ -244,12 +269,12 @@ async def resolve_map_node_over_list_results(results):
@tracer.start_as_current_span("Execute Node")
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None, execution_list=None, executed=None):
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None, execution_list=None, executed=None):
with context_set_execution_list_and_inputs(FrozenTopologicalSort.from_topological_sort(execution_list) if execution_list is not None else None, frozenset(executed) if executed is not None else None):
return await __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb, hidden_inputs)
return await __async_map_node_over_list(prompt_id=prompt_id, unique_id=unique_id, obj=obj, input_data_all=input_data_all, func=func, allow_interrupt=allow_interrupt, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
async def __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
async def __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
span = get_current_span()
class_type = obj.__class__.__name__
span.set_attribute("class_type", class_type)
@ -312,13 +337,16 @@ async def __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all,
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
# in case of dynamic inputs, restructure inputs to expected nested dict
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# V1
else:
f = getattr(obj, func)
@ -376,8 +404,8 @@ def merge_result_data(results, obj):
return output
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None, inputs=None, execution_list=None, executed=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs, execution_list=execution_list, executed=executed)
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None, execution_list=None, executed=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, execution_list=execution_list, executed=executed)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@ -484,7 +512,7 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca
return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list: ExecutionList, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple:
async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list: ExecutionList, pending_subgraph_results, pending_async_nodes, ui_outputs) -> RecursiveExecutionTuple:
unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id)
@ -492,11 +520,15 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i
inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type']
class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.get(unique_id) is not None:
cached = caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {}
server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": cached_output.get("output", None), "prompt_id": prompt_id}, server.client_id)
cached_ui = cached.ui or {}
server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id}, server.client_id)
if cached.ui is not None:
ui_outputs[unique_id] = cached.ui
get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, cached)
return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)
input_data_all = None
@ -526,8 +558,8 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i
for r in result:
if is_link(r):
source_node, source_output = r[0], r[1]
node_output = caches.outputs.get(source_node)[source_output]
for o in node_output:
node_cached = execution_list.get_cache(source_node, unique_id)
for o in node_cached.outputs[source_output]:
resolved_output.append(o)
else:
@ -535,10 +567,11 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i
resolved_outputs.append(tuple(resolved_output))
output_data = merge_result_data(resolved_outputs, class_def)
output_ui = []
del pending_subgraph_results[unique_id]
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", {"node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id}, server.client_id)
@ -553,7 +586,7 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i
else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present:
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs, execution_list=execution_list, executed=executed)
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, execution_list=execution_list, executed=executed, v3_data=v3_data)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], []))
required_inputs = [x for x in required_inputs if isinstance(x, str) and (
@ -587,7 +620,7 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs, inputs=inputs, execution_list=execution_list, executed=executed)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, execution_list=execution_list, executed=executed, v3_data=v3_data)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
@ -600,7 +633,7 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i
asyncio.create_task(await_completion())
return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
if len(output_ui) > 0:
caches.ui.set(unique_id, {
ui_outputs[unique_id] = {
"meta": {
"node_id": unique_id,
"display_node": display_node_id,
@ -608,7 +641,7 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i
"real_node_id": real_node_id,
},
"output": output_ui
})
}
if server.client_id is not None:
server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id},
server.client_id)
@ -622,10 +655,6 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i
if new_graph is None:
cached_outputs.append((False, node_outputs))
else:
# Check for conflicts
for node_id in new_graph.keys():
if dynprompt.has_node(node_id):
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
for node_id, node_info in new_graph.items():
new_node_ids.append(node_id)
display_id = node_info.get("override_display_id", unique_id)
@ -646,11 +675,16 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i
subcache.clean_unused()
for node_id in new_output_ids:
execution_list.add_node(node_id)
execution_list.cache_link(node_id, unique_id)
for link in new_output_links:
execution_list.add_strong_link(link[0], link[1], unique_id)
pending_subgraph_results[unique_id] = cached_outputs
return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data)
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry)
except interruption.InterruptProcessingException as iex:
logger.info("Processing interrupted")
@ -702,10 +736,17 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i
return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None)
class CacheArgs(TypedDict):
ram: NotRequired[int]
lru: NotRequired[float]
class PromptExecutor:
def __init__(self, server: ExecutorToClientProgress, cache_type: CacheType | Literal[False] = False, cache_size: int | None = None):
def __init__(self, server: ExecutorToClientProgress, cache_type: CacheType | Literal[False] = False, cache_args: Optional[CacheArgs] = None):
self.status_messages = []
self.caches: Optional[CacheSet] = None
self.success = None
self.cache_size = cache_size
self.cache_args = cache_args
self.cache_type = cache_type
self.server = server
self.raise_exceptions = False
@ -714,7 +755,7 @@ class PromptExecutor:
def reset(self):
self.success = True
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = []
def add_message(self, event, data: dict, broadcast: bool):
@ -819,6 +860,7 @@ class PromptExecutor:
broadcast=False)
pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
@ -833,7 +875,7 @@ class PromptExecutor:
break
assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@ -842,18 +884,16 @@ class PromptExecutor:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", {"prompt_id": prompt_id}, broadcast=False)
ui_outputs = {}
meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids()
for node_id in all_node_ids:
ui_info = self.caches.ui.get(node_id)
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
@ -883,9 +923,6 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
class_type = prompt[unique_id]['class_type']
obj_class = get_nodes().NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES()
valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {})))
error: ValidationErrorDict
errors = []
valid = True
@ -893,9 +930,11 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
validate_function_inputs = []
validate_has_kwargs = False
if issubclass(obj_class, _ComfyNodeInternal):
class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
validate_function_name = "validate_inputs"
validate_function = first_real_override(obj_class, validate_function_name)
else:
class_inputs = obj_class.INPUT_TYPES()
validate_function_name = "VALIDATE_INPUTS"
validate_function = getattr(obj_class, validate_function_name, None)
if validate_function is not None:
@ -904,6 +943,8 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
validate_has_kwargs = argspec.varkw is not None
received_types = {}
valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {})))
for x in valid_inputs:
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None
@ -1085,7 +1126,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:
@ -1093,7 +1134,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered:
for i, r in enumerate(ret):
@ -1320,8 +1361,7 @@ class PromptQueue(AbstractPromptQueue):
self.server.queue_updated()
return copy.deepcopy(item_with_future.queue_tuple), task_id
def task_done(self, item_id: str, outputs: HistoryResultDict,
status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None):
def task_done(self, item_id: str, outputs: HistoryResultDict, status: Optional[ExecutionStatus], error_details: typing.Optional[ExecutionErrorMessage] = None, process_item: typing.Optional[typing.Callable[[QueueTuple], QueueItem]] = None):
history_result = outputs
with self.mutex:
queue_item = self.currently_running.pop(item_id)
@ -1331,16 +1371,14 @@ class PromptQueue(AbstractPromptQueue):
status_dict = None
if status is not None:
status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=error_details)
status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=None)
outputs_ = history_result["outputs"]
# Remove sensitive data from extra_data before storing in history
for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS:
if sensitive_val in prompt[3]:
prompt[3].pop(sensitive_val)
if process_item is not None:
prompt = process_item(prompt)
history_entry: HistoryEntry = {
"prompt": prompt,
"prompt": prompt.queue_tuple if isinstance(prompt, QueueItem) else prompt,
"outputs": copy.deepcopy(outputs_),
}
if status_dict is not None:

View File

@ -24,6 +24,12 @@ _module_properties = create_module_properties()
logger = logging.getLogger(__name__)
# todo: investigate what this is actually trying to do
# System User Protection - Protects system directories from HTTP endpoint access
# System Users are internal-only users that cannot be accessed via HTTP endpoints.
# They use the '__' prefix convention (similar to Python's private member convention).
_SYSTEM_USER_PREFIX = "__"
@_module_properties.getter
def _supported_pt_extensions() -> set[str]:
@ -58,6 +64,65 @@ def _resolve_path_with_compatibility(path: Path | str) -> PurePosixPath | Path:
return Path(path).resolve()
def get_system_user_directory(name: str = "system") -> str:
"""
Get the path to a System User directory.
System User directories (prefixed with '__') are only accessible via internal API,
not through HTTP endpoints. Use this for storing system-internal data that
should not be exposed to users.
Args:
name: System user name (e.g., "system", "cache"). Must be alphanumeric
with underscores allowed, but cannot start with underscore.
Returns:
Absolute path to the system user directory.
Raises:
ValueError: If name is empty, invalid, or starts with underscore.
Example:
>>> get_system_user_directory("cache")
'/path/to/user/__cache'
"""
if not name or not isinstance(name, str):
raise ValueError("System user name cannot be empty")
if not name.replace("_", "").isalnum():
raise ValueError(f"Invalid system user name: '{name}'")
if name.startswith("_"):
raise ValueError("System user name should not start with underscore")
return os.path.join(get_user_directory(), f"{_SYSTEM_USER_PREFIX}{name}")
def get_public_user_directory(user_id: str) -> str | None:
"""
Get the path to a Public User directory for HTTP endpoint access.
This function provides structural security by returning None for any
System User (prefixed with '__'). All HTTP endpoints should use this
function instead of directly constructing user paths.
Args:
user_id: User identifier from HTTP request.
Returns:
Absolute path to the user directory, or None if user_id is invalid
or refers to a System User.
Example:
>>> get_public_user_directory("default")
'/path/to/user/default'
>>> get_public_user_directory("__system")
None
"""
if not user_id or not isinstance(user_id, str):
return None
if user_id.startswith(_SYSTEM_USER_PREFIX):
return None
return os.path.join(get_user_directory(), user_id)
def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories=False, replace_existing=True, base_paths_from_configuration=True):
"""
Populates the folder names and paths object with the default, upstream model directories and custom_nodes directory.
@ -111,6 +176,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
ModelPaths(["huggingface"], supported_extensions=set()),
ModelPaths(["model_patches"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["audio_encoders"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["latent_upscale_models"], supported_extensions=set(supported_pt_extensions)),
hf_cache_paths,
hf_xet,
]

View File

@ -15,6 +15,7 @@ output_directory: str
temp_directory: str
input_directory: str
supported_pt_extensions: set[str]
extension_mimetypes_cache: dict[str, str]
# Functions
@ -39,7 +40,7 @@ def get_output_directory() -> str: ...
def get_temp_directory() -> str: ...
def get_input_directory() -> str: ...
def get_input_directory(mkdirs: bool = ...) -> str: ...
def get_user_directory() -> str: ...
@ -108,3 +109,9 @@ def filter_files_content_types(files: List[str], content_types: List[Literal["im
def get_input_subfolders() -> list[str]: ...
def get_system_user_directory(name: str = ...) -> str: ...
def get_public_user_directory(user_id: str) -> Optional[str]: ...

View File

@ -15,14 +15,24 @@ from ..component_model.executor_types import UnencodedPreviewImageMessage
from ..execution_context import current_execution_context
from ..model_downloader import get_or_download, KNOWN_APPROX_VAES
from ..taesd.taesd import TAESD
from ..sd import VAE
from ..utils import load_torch_file
MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
logger = logging.getLogger(__name__)
def preview_to_image(latent_image) -> Image:
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
)
def preview_to_image(latent_image, do_scale=True) -> Image.Image:
if do_scale:
latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1
.mul(0xFF) # to 0..255
)
else:
latents_ubyte = (latent_image.clamp(0, 1)
.mul(0xFF) # to 0..255
)
if model_management.directml_device is not None:
latents_ubyte = latents_ubyte.to(dtype=torch.uint8)
latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=model_management.device_supports_non_blocking(latent_image.device))
@ -31,7 +41,7 @@ def preview_to_image(latent_image) -> Image:
class LatentPreviewer:
def decode_latent_to_preview(self, x0) -> Image:
def decode_latent_to_preview(self, x0) -> Image.Image:
raise NotImplementedError
def decode_latent_to_preview_image(self, preview_format, x0) -> UnencodedPreviewImageMessage:
@ -49,14 +59,23 @@ class TAESDPreviewerImpl(LatentPreviewer):
return preview_to_image(x_sample)
class TAEHVPreviewerImpl(TAESDPreviewerImpl):
def decode_latent_to_preview(self, x0):
x_sample = self.taesd.decode(x0[:1, :, :1])[0][0]
return preview_to_image(x_sample, do_scale=False)
class Latent2RGBPreviewer(LatentPreviewer):
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None):
def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None):
self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1)
self.latent_rgb_factors_bias = None
if latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu")
self.latent_rgb_factors_reshape = latent_rgb_factors_reshape
def decode_latent_to_preview(self, x0):
if self.latent_rgb_factors_reshape is not None:
x0 = self.latent_rgb_factors_reshape(x0)
self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device)
if self.latent_rgb_factors_bias is not None:
self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device)
@ -91,14 +110,19 @@ def get_previewer(device, latent_format):
if method == LatentPreviewMethod.TAESD:
if taesd_decoder_path:
taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
previewer = TAESDPreviewerImpl(taesd)
if latent_format.taesd_decoder_name in VIDEO_TAES:
taesd = VAE(load_torch_file(taesd_decoder_path))
taesd.first_stage_model.show_progress_bar = False
previewer = TAEHVPreviewerImpl(taesd)
else:
taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device)
previewer = TAESDPreviewerImpl(taesd)
else:
logger.warning("Warning: TAESD previews enabled, but could not find models/vae_approx/{}".format(latent_format.taesd_decoder_name))
if previewer is None:
if latent_format.latent_rgb_factors is not None:
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias)
previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape)
return previewer

View File

@ -1,30 +1,29 @@
from .main_pre import tracer
import asyncio
import contextvars
import gc
import logging
import os
import shutil
import sys
import threading
import time
from pathlib import Path
from typing import Optional
from ..cli_args_types import Configuration
from ..component_model.file_counter import cleanup_temp as fc_cleanup_temp
from ..execution_context import current_execution_context
from comfy.component_model.abstract_prompt_queue import AbstractPromptQueue
from . import hook_breaker_ac10a0
from .extra_model_paths import load_extra_path_config
from .. import model_management
from ..analytics.analytics import initialize_event_tracking
from ..cli_args_types import Configuration
from ..cmd import cuda_malloc
from ..cmd import folder_paths
from ..cmd import server as server_module
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.entrypoints_common import configure_application_paths, executor_from_args
from ..component_model.file_counter import cleanup_temp as fc_cleanup_temp
from ..distributed.distributed_prompt_queue import DistributedPromptQueue
from ..distributed.server_stub import ServerStub
from ..execution_context import current_execution_context
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes_context import get_nodes
@ -44,22 +43,27 @@ def cuda_malloc_warning():
"\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
asyncio.run(_prompt_worker(q, server_instance))
def handle_comfyui_manager_unavailable(args: Configuration):
if not args.windows_standalone_build:
logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
args.enable_manager = False
async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer):
from ..cmd import execution
from ..component_model import queue_types
from .. import model_management
args = current_execution_context().configuration
cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0:
cache_type = execution.CacheType.LRU
elif args.cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none:
cache_type = execution.CacheType.DEPENDENCY_AWARE
cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={"lru": args.cache_lru, "ram": args.cache_ram})
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0
@ -76,6 +80,15 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
prompt_id = item[1]
server_instance.last_prompt_id = prompt_id
sensitive = item[5]
extra_data = item[3].copy()
for k in sensitive:
extra_data[k] = sensitive[k]
e.execute(item[2], prompt_id, extra_data, item[4])
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
await e.execute_async(item[2], prompt_id, item[3], item[4])
need_gc = True
@ -96,13 +109,16 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
status_str='success' if e.success else 'error',
completed=e.success,
messages=messages),
error_details=error_details)
error_details=error_details,
process_item=remove_sensitive,
)
if server_instance.client_id is not None:
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id},
server_instance.client_id)
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
# Log Time in a more readable way after 10 minutes
if execution_time > 600:
execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
@ -133,7 +149,7 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.
hook_breaker_ac10a0.restore_functions()
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
async def run(server_instance, address='', port=8188, call_on_start=None):
addresses = []
for addr in address.split(","):
addresses.append((addr, port))
@ -194,6 +210,23 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
logger.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
# todo: the manager code has to live inside vanilla_node_importing, it has to deal with a git repo already being in custom_nodes
# if args.enable_manager:
# if importlib.util.find_spec("comfyui_manager"):
# import comfyui_manager
#
# if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'):
# handle_comfyui_manager_unavailable(args)
# else:
# handle_comfyui_manager_unavailable(args)
#
# if args.enable_manager:
# try:
# import comfyui_manager
# comfyui_manager.prestartup()
# except:
# pass
# configure extra model paths earlier
try:
extra_model_paths_config_path = os.path.join(os_getcwd, "extra_model_paths.yaml")
@ -224,6 +257,15 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
loop = asyncio.get_event_loop()
server = server_module.PromptServer(loop)
# todo: the manager code has to live inside vanilla_node_importing, it has to deal with a git repo already being in custom_nodes
# if args.enable_manager and not args.disable_manager_ui:
# try:
# import comfyui_manager
# comfyui_manager.start()
# except:
# pass
if args.external_address is not None:
server.external_address = args.external_address
@ -317,8 +359,7 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None):
try:
await server.setup()
await run(server, address=first_listen_addr, port=args.port, verbose=not args.dont_print_server,
call_on_start=call_on_start)
await run(server, address=first_listen_addr, port=args.port, call_on_start=call_on_start)
except (asyncio.CancelledError, KeyboardInterrupt):
logger.debug("Stopped server")
finally:

View File

@ -12,6 +12,8 @@ import socket
import struct
import sys
import traceback
import time
import typing
import urllib
import uuid
@ -39,7 +41,7 @@ from .. import node_helpers
from .. import utils
from ..api_server.routes.internal.internal_routes import InternalRoutes
from ..app.custom_node_manager import CustomNodeManager
from ..app.frontend_management import FrontendManager
from ..app.frontend_management import FrontendManager, parse_version
from ..app.model_manager import ModelFileManager
from ..app.user_manager import UserManager
from ..cli_args import args
@ -70,6 +72,13 @@ class HeuristicPath(NamedTuple):
# Import cache control middleware
from ..middleware.cache_middleware import cache_control
# todo: what is this really trying to do?
LOADED_MODULE_DIRS = {}
# todo: is this really how we want to enable the manager?
if args.enable_manager:
import comfyui_manager
async def send_socket_catch_exception(function, message):
try:
await function(message)
@ -144,7 +153,7 @@ def create_cors_middleware(allowed_origin: str):
response = await handler(request)
response.headers['Access-Control-Allow-Origin'] = allowed_origin
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, traceparent, tracestate'
response.headers['Access-Control-Allow-Credentials'] = 'true'
return response
@ -215,9 +224,23 @@ def create_origin_only_middleware():
return origin_only_middleware
def create_block_external_middleware():
@web.middleware
async def block_external_middleware(request: web.Request, handler):
if request.method == "OPTIONS":
# Pre-flight request. Reply successfully:
response = web.Response()
else:
response = await handler(request)
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
return response
return block_external_middleware
class PromptServer(ExecutorToClientProgress):
instance: Optional['PromptServer'] = None
def __init__(self, loop):
# todo: this really needs to be set up differently, because sometimes the prompt server will not be initialized
PromptServer.instance = self
@ -230,6 +253,7 @@ class PromptServer(ExecutorToClientProgress):
self.user_manager = UserManager()
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager()
self.internal_routes = InternalRoutes(self)
# todo: this is probably read by custom nodes elsewhere
self.supports: List[str] = ["custom_nodes_from_web"]
@ -251,6 +275,12 @@ class PromptServer(ExecutorToClientProgress):
else:
middlewares.append(create_origin_only_middleware())
if args.disable_api_nodes:
middlewares.append(create_block_external_middleware())
if args.enable_manager:
middlewares.append(comfyui_manager.create_middleware())
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app: web.Application = web.Application(client_max_size=max_upload_size,
handler_args={'max_field_size': 16380},
@ -634,7 +664,7 @@ class PromptServer(ExecutorToClientProgress):
system_stats = {
"system": {
"os": os.name,
"os": sys.platform,
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": __version__,
@ -746,8 +776,9 @@ class PromptServer(ExecutorToClientProgress):
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue_volatile()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
remove_sensitive = lambda queue: [x[:5] for x in queue]
queue_info['queue_running'] = remove_sensitive(current_queue[0])
queue_info['queue_pending'] = remove_sensitive(current_queue[1])
return web.json_response(queue_info)
@routes.post("/prompt")
@ -782,8 +813,13 @@ class PromptServer(ExecutorToClientProgress):
extra_data["client_id"] = json_data["client_id"]
if valid[0]:
outputs_to_execute = valid[2]
sensitive = {}
for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS:
if sensitive_val in extra_data:
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
self.prompt_queue.put(
QueueItem(queue_tuple=(number, prompt_id, prompt, extra_data, outputs_to_execute),
QueueItem(queue_tuple=(number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive),
completed=None))
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
@ -1112,6 +1148,7 @@ class PromptServer(ExecutorToClientProgress):
self.model_file_manager.add_routes(self.routes)
# todo: needs to use module directories
self.custom_node_manager.add_routes(self.routes, self.app, {})
self.subgraph_manager.add_routes(self.routes, LOADED_MODULE_DIRS.items())
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation.
@ -1132,11 +1169,31 @@ class PromptServer(ExecutorToClientProgress):
for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
self.app.add_routes([web.static('/extensions/' + name, dir, follow_symlinks=True)])
workflow_templates_path = FrontendManager.templates_path()
if workflow_templates_path:
self.app.add_routes([
web.static('/templates', workflow_templates_path)
])
installed_templates_version = FrontendManager.get_installed_templates_version()
use_legacy_templates = True
if installed_templates_version:
try:
use_legacy_templates = (
parse_version(installed_templates_version)
< parse_version("0.3.0")
)
except Exception as exc:
logging.warning(
"Unable to parse templates version '%s': %s",
installed_templates_version,
exc,
)
if use_legacy_templates:
workflow_templates_path = FrontendManager.legacy_templates_path()
if workflow_templates_path:
self.app.add_routes([
web.static('/templates', workflow_templates_path)
])
else:
handler = FrontendManager.template_asset_handler()
if handler:
self.app.router.add_get("/templates/{path:.*}", handler)
# Serve embedded documentation from the package
embedded_docs_path = FrontendManager.embedded_docs_path()

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import typing
from abc import ABCMeta, abstractmethod
from .executor_types import HistoryResultDict
from .executor_types import HistoryResultDict, ExecutionErrorMessage
from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation, AbstractPromptQueueGetCurrentQueueItems
@ -43,10 +43,11 @@ class AbstractPromptQueue(metaclass=ABCMeta):
pass
@abstractmethod
def task_done(self, item_id: str, outputs: HistoryResultDict,
status: typing.Optional[ExecutionStatus]):
def task_done(self, item_id: str, outputs: HistoryResultDict, status: typing.Optional[ExecutionStatus], error_details: typing.Optional[ExecutionErrorMessage] = None, process_item: typing.Optional[typing.Callable[[QueueTuple], QueueItem]] = None):
"""
Signals to the user interface that the task with the specified id is completed
:param error_details:
:param process_item:
:param item_id: the ID of the task that should be marked as completed
:param outputs: an opaque dictionary of outputs
:param status:

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import copy
import typing
from enum import Enum
from typing import NamedTuple, Optional, List, Literal, Sequence
from typing import Tuple
@ -10,7 +11,12 @@ from typing_extensions import NotRequired, TypedDict
from .outputs_types import OutputsDict
QueueTuple = Tuple[float, str, dict, dict, list]
if typing.TYPE_CHECKING:
from .executor_types import ExecutionErrorMessage
# todo: migrate this and the tree of objects here to a NamedTuple
# number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive
# todo: sensitive dictionary data is actually a JSON value
QueueTuple = Tuple[float, str, dict, dict, list, Optional[dict[str, str]]]
MAXIMUM_HISTORY_SIZE = 10000
@ -63,6 +69,7 @@ class ExecutionStatusAsDict(TypedDict):
status_str: Literal['success', 'error']
completed: bool
messages: List[str]
error_details: NotRequired[ExecutionErrorMessage]
class Flags(TypedDict, total=False):
@ -98,7 +105,8 @@ class NamedQueueTuple(dict):
prompt_id=queue_tuple[1],
prompt=queue_tuple[2],
extra_data=queue_tuple[3] if len(queue_tuple) > 3 else None,
good_outputs=queue_tuple[4] if len(queue_tuple) > 4 else None
good_outputs=queue_tuple[4] if len(queue_tuple) > 4 else None,
sensitive=queue_tuple[5] if len(queue_tuple) > 5 else None,
)
# Store the original tuple in a slot, making it invisible to json.dumps.
self.queue_tuple = queue_tuple
@ -127,6 +135,12 @@ class NamedQueueTuple(dict):
return self.queue_tuple[4]
return None
@property
def sensitive(self) -> Optional[dict]:
if len(self.queue_tuple) > 5:
return self.queue_tuple[5]
return None
class QueueItem(NamedQueueTuple):
"""

View File

@ -54,26 +54,36 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int = 0):
def __init__(self, index_list: list[int], dim: int = 0, total_frames: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.dim = dim
self.total_frames = total_frames
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None:
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
idx = [slice(None)] * dim + [self.index_list]
return full[idx].to(device)
idx = tuple([slice(None)] * dim + [self.index_list])
window = full[idx]
if retain_index_list:
idx = tuple([slice(None)] * dim + [retain_index_list])
window[idx] = full[idx]
return window.to(device)
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
idx = [slice(None)] * dim + [self.index_list]
idx = tuple([slice(None)] * dim + [self.index_list])
full[idx] += to_add
return full
def get_region_index(self, num_regions: int) -> int:
region_idx = int(self.center_ratio * num_regions)
return min(max(region_idx, 0), num_regions - 1)
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@ -101,7 +111,8 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int = 1, context_overlap: int = 0, context_stride: int = 1, closed_loop=False, dim=0):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int = 1, context_overlap: int = 0, context_stride: int = 1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@ -110,13 +121,18 @@ class IndexListContextHandler(ContextHandlerABC):
self.closed_loop = closed_loop
self.dim = dim
self._step = 0
self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows
self.callbacks = {}
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
logger.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
logger.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
if self.cond_retain_index_list:
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
return True
return False
@ -130,6 +146,11 @@ class IndexListContextHandler(ContextHandlerABC):
return None
# reuse or resize cond items to match context requirements
resized_cond = []
# if multiple conds, split based on primary region
if self.split_conds_to_windows and len(cond_in) > 1:
region = window.get_region_index(len(cond_in))
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:
resized_actual_cond = actual_cond.copy()
@ -153,12 +174,19 @@ class IndexListContextHandler(ContextHandlerABC):
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
if isinstance(cond_value, torch.Tensor):
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
audio_cond = cond_value.cond
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
new_cond_item[cond_key].cond = window.context_length
@ -171,7 +199,7 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
@ -180,7 +208,7 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
@ -256,8 +284,8 @@ class IndexListContextHandler(ContextHandlerABC):
prev_weight = (bias_total / (bias_total + bias))
new_weight = (bias / (bias_total + bias))
# account for dims of tensors
idx_window = [slice(None)] * self.dim + [idx]
pos_window = [slice(None)] * self.dim + [pos]
idx_window = tuple([slice(None)] * self.dim + [idx])
pos_window = tuple([slice(None)] * self.dim + [pos])
# apply new values
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
biases_final[i][idx] = bias_total + bias
@ -293,6 +321,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
)
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
model_options = extra_args.get("model_options", None)
if model_options is None:
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
handler: IndexListContextHandler = model_options.get("context_handler", None)
if handler is None:
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
if not handler.freenoise:
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
def create_sampler_sample_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
"ContextWindows_sampler_sample",
_sampler_sample_wrapper
)
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
@ -552,3 +602,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
logging.info("Context windows: Applying FreeNoise")
generator = torch.Generator(device='cpu').manual_seed(seed)
latent_video_length = noise.shape[dim]
delta = context_length - context_overlap
for start_idx in range(0, latent_video_length - context_length, delta):
place_idx = start_idx + context_length
actual_delta = min(delta, latent_video_length - place_idx)
if actual_delta <= 0:
break
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
source_slice = [slice(None)] * noise.ndim
source_slice[dim] = list_idx
target_slice = [slice(None)] * noise.ndim
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
noise[tuple(target_slice)] = noise[tuple(source_slice)]
return noise

View File

@ -316,11 +316,13 @@ class ControlLoraOps:
self.bias = None
def forward(self, input):
weight, bias = ops.cast_bias_weight(self, input)
weight, bias, offload_stream = ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, weight, bias)
x = torch.nn.functional.linear(input, weight, bias)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
class Conv2d(torch.nn.Module, ops.CastWeightBiasOp):
def __init__(
@ -355,12 +357,13 @@ class ControlLoraOps:
self.down = None
def forward(self, input):
weight, bias = ops.cast_bias_weight(self, input)
weight, bias, offload_stream = ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): # TODO? model_options

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import typing
from ..cmd.main_pre import tracer
import asyncio
@ -22,7 +24,7 @@ from .server_stub import ServerStub
from ..auth.permissions import jwt_decode
from ..cmd.server import PromptServer
from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict, ExecutionErrorMessage
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation, \
ExecutionError
@ -163,7 +165,8 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue):
return item, item[1]
def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional['ExecutionErrorMessage'] = None):
def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None, process_item: typing.Optional[typing.Callable[[QueueTuple], QueueItem]] = None):
# todo: should we do the process_item? it's clearing sensitive data. but what is the idea? why do things this way, it's crazy
# callee: executed on the worker thread
if "outputs" in outputs:
outputs: HistoryResultDict

View File

@ -7,6 +7,7 @@ class LatentFormat:
latent_dimensions = 2
latent_rgb_factors = None
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
taesd_decoder_name = None
def process_in(self, latent):
@ -191,6 +192,55 @@ class Flux(SD3):
return (latent / self.scale_factor) + self.shift_factor
class Flux2(LatentFormat):
latent_channels = 128
def __init__(self):
self.latent_rgb_factors = [
[0.0058, 0.0113, 0.0073],
[0.0495, 0.0443, 0.0836],
[-0.0099, 0.0096, 0.0644],
[0.2144, 0.3009, 0.3652],
[0.0166, -0.0039, -0.0054],
[0.0157, 0.0103, -0.0160],
[-0.0398, 0.0902, -0.0235],
[-0.0052, 0.0095, 0.0109],
[-0.3527, -0.2712, -0.1666],
[-0.0301, -0.0356, -0.0180],
[-0.0107, 0.0078, 0.0013],
[0.0746, 0.0090, -0.0941],
[0.0156, 0.0169, 0.0070],
[-0.0034, -0.0040, -0.0114],
[0.0032, 0.0181, 0.0080],
[-0.0939, -0.0008, 0.0186],
[0.0018, 0.0043, 0.0104],
[0.0284, 0.0056, -0.0127],
[-0.0024, -0.0022, -0.0030],
[0.1207, -0.0026, 0.0065],
[0.0128, 0.0101, 0.0142],
[0.0137, -0.0072, -0.0007],
[0.0095, 0.0092, -0.0059],
[0.0000, -0.0077, -0.0049],
[-0.0465, -0.0204, -0.0312],
[0.0095, 0.0012, -0.0066],
[0.0290, -0.0034, 0.0025],
[0.0220, 0.0169, -0.0048],
[-0.0332, -0.0457, -0.0468],
[-0.0085, 0.0389, 0.0609],
[-0.0076, 0.0003, -0.0043],
[-0.0111, -0.0460, -0.0614],
]
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
def process_in(self, latent):
return latent
def process_out(self, latent):
return latent
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
@ -240,210 +290,214 @@ class LTXV(LatentFormat):
def __init__(self):
self.latent_rgb_factors = [
[ 1.1202e-02, -6.3815e-04, -1.0021e-02],
[ 8.6031e-02, 6.5813e-02, 9.5409e-04],
[1.1202e-02, -6.3815e-04, -1.0021e-02],
[8.6031e-02, 6.5813e-02, 9.5409e-04],
[-1.2576e-02, -7.5734e-03, -4.0528e-03],
[ 9.4063e-03, -2.1688e-03, 2.6093e-03],
[ 3.7636e-03, 1.2765e-02, 9.1548e-03],
[ 2.1024e-02, -5.2973e-03, 3.4373e-03],
[9.4063e-03, -2.1688e-03, 2.6093e-03],
[3.7636e-03, 1.2765e-02, 9.1548e-03],
[2.1024e-02, -5.2973e-03, 3.4373e-03],
[-8.8896e-03, -1.9703e-02, -1.8761e-02],
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
[-1.3160e-02, -1.0523e-02, 1.9709e-03],
[-1.5152e-03, -6.9891e-03, -7.5810e-03],
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
[ 1.3617e-02, 4.7077e-03, -2.0045e-03],
[ 1.0256e-02, 7.7318e-03, 1.3948e-02],
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
[ 7.3407e-03, 1.5628e-02, 4.4865e-04],
[ 9.5357e-04, -2.9518e-03, -1.4760e-02],
[ 1.9143e-02, 1.0868e-02, 1.2264e-02],
[ 4.4575e-03, 3.6682e-05, -6.8508e-03],
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
[ 3.3902e-02, 3.3405e-02, 3.7454e-02],
[-1.7247e-03, 4.6560e-04, -3.3839e-03],
[1.3617e-02, 4.7077e-03, -2.0045e-03],
[1.0256e-02, 7.7318e-03, 1.3948e-02],
[-1.6108e-02, -6.2151e-03, 1.1561e-03],
[7.3407e-03, 1.5628e-02, 4.4865e-04],
[9.5357e-04, -2.9518e-03, -1.4760e-02],
[1.9143e-02, 1.0868e-02, 1.2264e-02],
[4.4575e-03, 3.6682e-05, -6.8508e-03],
[-4.5681e-04, 3.2570e-03, 7.7929e-03],
[3.3902e-02, 3.3405e-02, 3.7454e-02],
[-2.3001e-02, -2.4877e-03, -3.1033e-03],
[ 5.0265e-02, 3.8841e-02, 3.3539e-02],
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
[5.0265e-02, 3.8841e-02, 3.3539e-02],
[-4.1018e-03, -1.1095e-03, 1.5859e-03],
[-1.2689e-01, -1.3107e-01, -2.1005e-01],
[ 2.6276e-02, 1.4189e-02, -3.5963e-03],
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
[2.6276e-02, 1.4189e-02, -3.5963e-03],
[-4.8679e-03, 8.8486e-03, 7.8029e-03],
[-1.6610e-03, -4.8597e-03, -5.2060e-03],
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
[-2.1010e-03, 2.3610e-03, 9.3796e-03],
[-2.2482e-02, -2.1305e-02, -1.5087e-02],
[-1.5753e-02, -1.0646e-02, -6.5083e-03],
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
[ 1.1951e-02, 2.0712e-02, 1.6191e-02],
[-4.6975e-03, 5.0288e-03, -6.7390e-03],
[1.1951e-02, 2.0712e-02, 1.6191e-02],
[-6.3704e-03, -8.4827e-03, -9.5483e-03],
[ 7.2610e-03, -9.9326e-03, -2.2978e-02],
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
[7.2610e-03, -9.9326e-03, -2.2978e-02],
[-9.1904e-04, 6.2882e-03, 9.5720e-03],
[-3.7178e-02, -3.7123e-02, -5.6713e-02],
[-1.3373e-01, -1.0720e-01, -5.3801e-02],
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
[-5.3702e-03, 8.1256e-03, 8.8397e-03],
[-1.5247e-01, -2.1437e-01, -2.1843e-01],
[ 3.1441e-02, 7.0335e-03, -9.7541e-03],
[ 2.1528e-03, -8.9817e-03, -2.1023e-02],
[ 3.8461e-03, -5.8957e-03, -1.5014e-02],
[3.1441e-02, 7.0335e-03, -9.7541e-03],
[2.1528e-03, -8.9817e-03, -2.1023e-02],
[3.8461e-03, -5.8957e-03, -1.5014e-02],
[-4.3470e-03, -1.2940e-02, -1.5972e-02],
[-5.4781e-03, -1.0842e-02, -3.0204e-03],
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
[-6.5347e-03, 3.0806e-03, -1.0163e-02],
[-5.0414e-03, -7.1503e-03, -8.9686e-04],
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
[ 5.0914e-03, 1.2099e-02, 1.9968e-02],
[ 1.3758e-02, 1.1669e-02, 8.1958e-03],
[-8.5851e-03, -2.4351e-03, 1.0674e-03],
[-9.0016e-03, -9.6493e-03, 1.5692e-03],
[5.0914e-03, 1.2099e-02, 1.9968e-02],
[1.3758e-02, 1.1669e-02, 8.1958e-03],
[-1.0518e-02, -1.1575e-02, -4.1307e-03],
[-2.8410e-02, -3.1266e-02, -2.2149e-02],
[ 2.9336e-03, 3.6511e-02, 1.8717e-02],
[2.9336e-03, 3.6511e-02, 1.8717e-02],
[-1.6703e-02, -1.6696e-02, -4.4529e-03],
[ 4.8818e-02, 4.0063e-02, 8.7410e-03],
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
[ 3.0851e-03, 4.7750e-04, 1.2347e-02],
[4.8818e-02, 4.0063e-02, 8.7410e-03],
[-1.5066e-02, -5.7328e-04, 2.9785e-03],
[-1.7613e-02, -8.1034e-03, 1.3086e-02],
[-9.2633e-03, 1.0803e-02, -6.3489e-03],
[3.0851e-03, 4.7750e-04, 1.2347e-02],
[-2.2785e-02, -2.3043e-02, -2.6005e-02],
[-2.4787e-02, -1.5389e-02, -2.2104e-02],
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
[-2.3572e-02, 1.0544e-03, 1.2361e-02],
[-7.8915e-03, -1.2271e-03, -6.0968e-03],
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
[ 4.4216e-03, -7.3338e-03, -1.0464e-02],
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
[ 1.3673e-02, 8.8877e-03, 4.1253e-03],
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
[ 2.9476e-02, 2.1424e-02, 3.0424e-02],
[-1.1478e-02, -1.2543e-03, 6.2679e-03],
[-5.4229e-02, 2.6644e-02, 6.3394e-03],
[4.4216e-03, -7.3338e-03, -1.0464e-02],
[-4.5013e-03, 1.6082e-03, 1.4420e-02],
[1.3673e-02, 8.8877e-03, 4.1253e-03],
[-1.0145e-02, 9.0072e-03, 1.5695e-02],
[-5.6234e-03, 1.1847e-03, 8.1261e-03],
[-3.7171e-03, -5.3538e-03, 1.2590e-03],
[2.9476e-02, 2.1424e-02, 3.0424e-02],
[-3.4925e-02, -2.4340e-02, -2.5316e-02],
[-3.4127e-02, -2.2406e-02, -1.0589e-02],
[-1.7342e-02, -1.3249e-02, -1.0719e-02],
[-2.1478e-03, -8.6051e-03, -2.9878e-03],
[ 1.2089e-03, -4.2391e-03, -6.8569e-03],
[ 9.0411e-04, -6.6886e-03, -6.7547e-05],
[ 1.6048e-02, -1.0057e-02, -2.8929e-02],
[ 1.2290e-03, 1.0163e-02, 1.8861e-02],
[ 1.7264e-02, 2.7257e-04, 1.3785e-02],
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
[ 4.6782e-03, -5.2423e-03, 2.4467e-03],
[1.2089e-03, -4.2391e-03, -6.8569e-03],
[9.0411e-04, -6.6886e-03, -6.7547e-05],
[1.6048e-02, -1.0057e-02, -2.8929e-02],
[1.2290e-03, 1.0163e-02, 1.8861e-02],
[1.7264e-02, 2.7257e-04, 1.3785e-02],
[-1.3482e-02, -3.6427e-03, 6.7481e-04],
[4.6782e-03, -5.2423e-03, 2.4467e-03],
[-5.9113e-03, -6.2244e-03, -1.8162e-03],
[ 1.5496e-02, 1.4582e-02, 1.9514e-03],
[ 7.4958e-03, 1.5886e-03, -8.2305e-03],
[ 1.9086e-02, 1.6360e-03, -3.9674e-03],
[1.5496e-02, 1.4582e-02, 1.9514e-03],
[7.4958e-03, 1.5886e-03, -8.2305e-03],
[1.9086e-02, 1.6360e-03, -3.9674e-03],
[-5.7021e-03, -2.7307e-03, -4.1066e-03],
[ 1.7450e-03, 1.4602e-02, 2.5794e-02],
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
[ 1.1632e-02, 8.9193e-03, -7.2813e-03],
[ 7.5721e-03, 2.6784e-03, 1.1393e-02],
[ 5.1939e-03, 3.6903e-03, 1.4049e-02],
[1.7450e-03, 1.4602e-02, 2.5794e-02],
[-8.2788e-04, 2.2902e-03, 4.5161e-03],
[1.1632e-02, 8.9193e-03, -7.2813e-03],
[7.5721e-03, 2.6784e-03, 1.1393e-02],
[5.1939e-03, 3.6903e-03, 1.4049e-02],
[-1.8383e-02, -2.2529e-02, -2.4477e-02],
[ 5.8842e-04, -5.7874e-03, -1.4770e-02],
[5.8842e-04, -5.7874e-03, -1.4770e-02],
[-1.6125e-02, -8.6101e-03, -1.4533e-02],
[ 2.0540e-02, 2.0729e-02, 6.4338e-03],
[ 3.3587e-03, -1.1226e-02, -1.6444e-02],
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
[ 2.8130e-02, 2.3546e-02, 3.2791e-02],
[2.0540e-02, 2.0729e-02, 6.4338e-03],
[3.3587e-03, -1.1226e-02, -1.6444e-02],
[-1.4742e-03, -1.0489e-02, 1.7097e-03],
[2.8130e-02, 2.3546e-02, 3.2791e-02],
[-1.8532e-02, -1.2842e-02, -8.7756e-03],
[-8.0533e-03, -1.0771e-02, -1.7536e-02],
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
[-3.9009e-03, 1.6150e-02, 3.3359e-02],
[-7.4554e-03, -1.4154e-02, -6.1910e-03],
[ 3.4734e-03, -1.1370e-02, -1.0581e-02],
[ 1.1476e-02, 3.9281e-03, 2.8231e-03],
[ 7.1639e-03, -1.4741e-03, -3.8066e-03],
[ 2.2250e-03, -8.7552e-03, -9.5719e-03],
[ 2.4146e-02, 2.1696e-02, 2.8056e-02],
[3.4734e-03, -1.1370e-02, -1.0581e-02],
[1.1476e-02, 3.9281e-03, 2.8231e-03],
[7.1639e-03, -1.4741e-03, -3.8066e-03],
[2.2250e-03, -8.7552e-03, -9.5719e-03],
[2.4146e-02, 2.1696e-02, 2.8056e-02],
[-5.4365e-03, -2.4291e-02, -1.7802e-02],
[ 7.4263e-03, 1.0510e-02, 1.2705e-02],
[ 6.2669e-03, 6.2658e-03, 1.9211e-02],
[ 1.6378e-02, 9.4933e-03, 6.6971e-03],
[ 1.7173e-02, 2.3601e-02, 2.3296e-02],
[7.4263e-03, 1.0510e-02, 1.2705e-02],
[6.2669e-03, 6.2658e-03, 1.9211e-02],
[1.6378e-02, 9.4933e-03, 6.6971e-03],
[1.7173e-02, 2.3601e-02, 2.3296e-02],
[-1.4568e-02, -9.8279e-03, -1.1556e-02],
[ 1.4431e-02, 1.4430e-02, 6.6362e-03],
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
[ 6.1156e-03, 3.4700e-03, -2.6662e-03],
[1.4431e-02, 1.4430e-02, 6.6362e-03],
[-6.8230e-03, 1.8863e-02, 1.4555e-02],
[6.1156e-03, 3.4700e-03, -2.6662e-03],
[-2.6983e-03, -5.9402e-03, -9.2276e-03],
[ 1.0235e-02, 7.4173e-03, -7.6243e-03],
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
[ 2.4222e-03, -4.8039e-03, -1.5759e-02],
[ 2.6244e-02, 2.5951e-02, 2.0249e-02],
[ 1.5711e-02, 1.8498e-02, 2.7407e-03],
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
[ 3.8321e-02, 9.6622e-03, -1.9268e-02],
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
[1.0235e-02, 7.4173e-03, -7.6243e-03],
[-1.3255e-02, 1.9322e-02, -9.2153e-04],
[2.4222e-03, -4.8039e-03, -1.5759e-02],
[2.6244e-02, 2.5951e-02, 2.0249e-02],
[1.5711e-02, 1.8498e-02, 2.7407e-03],
[-2.1714e-03, 4.7214e-03, -2.2443e-02],
[-7.4747e-03, 7.4166e-03, 1.4430e-02],
[-8.3906e-03, -7.9776e-03, 9.7927e-03],
[3.8321e-02, 9.6622e-03, -1.9268e-02],
[-1.4605e-02, -6.7032e-03, 3.9675e-03]
]
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
class HunyuanVideo(LatentFormat):
latent_channels = 16
latent_dimensions = 3
scale_factor = 0.476986
latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445],
[ 0.0696, 0.0795, 0.0518],
[ 0.0135, -0.0945, -0.0282],
[ 0.0108, -0.0250, -0.0765],
[-0.0209, 0.0032, 0.0224],
[-0.0395, -0.0331, 0.0445],
[0.0696, 0.0795, 0.0518],
[0.0135, -0.0945, -0.0282],
[0.0108, -0.0250, -0.0765],
[-0.0209, 0.0032, 0.0224],
[-0.0804, -0.0254, -0.0639],
[-0.0991, 0.0271, -0.0669],
[-0.0991, 0.0271, -0.0669],
[-0.0646, -0.0422, -0.0400],
[-0.0696, -0.0595, -0.0894],
[-0.0799, -0.0208, -0.0375],
[ 0.1166, 0.1627, 0.0962],
[ 0.1165, 0.0432, 0.0407],
[0.1166, 0.1627, 0.0962],
[0.1165, 0.0432, 0.0407],
[-0.2315, -0.1920, -0.1355],
[-0.0270, 0.0401, -0.0821],
[-0.0270, 0.0401, -0.0821],
[-0.0616, -0.0997, -0.0727],
[ 0.0249, -0.0469, -0.1703]
[0.0249, -0.0469, -0.1703]
]
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
taesd_decoder_name = "taehv"
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[ 0.1817, 0.2284, 0.2423],
[0.1817, 0.2284, 0.2423],
[-0.0586, -0.0862, -0.3108],
[-0.4703, -0.4255, -0.3995],
[ 0.0803, 0.1963, 0.1001],
[-0.0820, -0.1050, 0.0400],
[ 0.2511, 0.3098, 0.2787],
[0.0803, 0.1963, 0.1001],
[-0.0820, -0.1050, 0.0400],
[0.2511, 0.3098, 0.2787],
[-0.1830, -0.2117, -0.0040],
[-0.0621, -0.2187, -0.0939],
[ 0.3619, 0.1082, 0.1455],
[ 0.3164, 0.3922, 0.2575],
[ 0.1152, 0.0231, -0.0462],
[0.3619, 0.1082, 0.1455],
[0.3164, 0.3922, 0.2575],
[0.1152, 0.0231, -0.0462],
[-0.1434, -0.3609, -0.3665],
[ 0.0635, 0.1471, 0.1680],
[0.0635, 0.1471, 0.1680],
[-0.3635, -0.1963, -0.3248],
[-0.1865, 0.0365, 0.2346],
[ 0.0447, 0.0994, 0.0881]
[-0.1865, 0.0365, 0.2346],
[0.0447, 0.0994, 0.0881]
]
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
class Wan21(LatentFormat):
latent_channels = 16
latent_dimensions = 3
latent_rgb_factors = [
[-0.1299, -0.1692, 0.2932],
[ 0.0671, 0.0406, 0.0442],
[ 0.3568, 0.2548, 0.1747],
[ 0.0372, 0.2344, 0.1420],
[ 0.0313, 0.0189, -0.0328],
[ 0.0296, -0.0956, -0.0665],
[-0.3477, -0.4059, -0.2925],
[ 0.0166, 0.1902, 0.1975],
[-0.0412, 0.0267, -0.1364],
[-0.1293, 0.0740, 0.1636],
[ 0.0680, 0.3019, 0.1128],
[ 0.0032, 0.0581, 0.0639],
[-0.1251, 0.0927, 0.1699],
[ 0.0060, -0.0633, 0.0005],
[ 0.3477, 0.2275, 0.2950],
[ 0.1984, 0.0913, 0.1861]
]
[-0.1299, -0.1692, 0.2932],
[0.0671, 0.0406, 0.0442],
[0.3568, 0.2548, 0.1747],
[0.0372, 0.2344, 0.1420],
[0.0313, 0.0189, -0.0328],
[0.0296, -0.0956, -0.0665],
[-0.3477, -0.4059, -0.2925],
[0.0166, 0.1902, 0.1975],
[-0.0412, 0.0267, -0.1364],
[-0.1293, 0.0740, 0.1636],
[0.0680, 0.3019, 0.1128],
[0.0032, 0.0581, 0.0639],
[-0.1251, 0.0927, 0.1699],
[0.0060, -0.0633, 0.0005],
[0.3477, 0.2275, 0.2950],
[0.1984, 0.0913, 0.1861]
]
latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
@ -458,8 +512,7 @@ class Wan21(LatentFormat):
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
]).view(1, self.latent_channels, 1, 1, 1)
self.taesd_decoder_name = None #TODO
self.taesd_decoder_name = "lighttaew2_1"
def process_in(self, latent):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
@ -471,81 +524,84 @@ class Wan21(LatentFormat):
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class Wan22(Wan21):
latent_channels = 48
latent_dimensions = 3
latent_rgb_factors = [
[ 0.0119, 0.0103, 0.0046],
[-0.1062, -0.0504, 0.0165],
[ 0.0140, 0.0409, 0.0491],
[-0.0813, -0.0677, 0.0607],
[ 0.0656, 0.0851, 0.0808],
[ 0.0264, 0.0463, 0.0912],
[ 0.0295, 0.0326, 0.0590],
[-0.0244, -0.0270, 0.0025],
[ 0.0443, -0.0102, 0.0288],
[-0.0465, -0.0090, -0.0205],
[ 0.0359, 0.0236, 0.0082],
[-0.0776, 0.0854, 0.1048],
[ 0.0564, 0.0264, 0.0561],
[ 0.0006, 0.0594, 0.0418],
[-0.0319, -0.0542, -0.0637],
[-0.0268, 0.0024, 0.0260],
[ 0.0539, 0.0265, 0.0358],
[-0.0359, -0.0312, -0.0287],
[-0.0285, -0.1032, -0.1237],
[ 0.1041, 0.0537, 0.0622],
[-0.0086, -0.0374, -0.0051],
[ 0.0390, 0.0670, 0.2863],
[ 0.0069, 0.0144, 0.0082],
[ 0.0006, -0.0167, 0.0079],
[ 0.0313, -0.0574, -0.0232],
[-0.1454, -0.0902, -0.0481],
[ 0.0714, 0.0827, 0.0447],
[-0.0304, -0.0574, -0.0196],
[ 0.0401, 0.0384, 0.0204],
[-0.0758, -0.0297, -0.0014],
[ 0.0568, 0.1307, 0.1372],
[-0.0055, -0.0310, -0.0380],
[ 0.0239, -0.0305, 0.0325],
[-0.0663, -0.0673, -0.0140],
[-0.0416, -0.0047, -0.0023],
[ 0.0166, 0.0112, -0.0093],
[-0.0211, 0.0011, 0.0331],
[ 0.1833, 0.1466, 0.2250],
[-0.0368, 0.0370, 0.0295],
[-0.3441, -0.3543, -0.2008],
[-0.0479, -0.0489, -0.0420],
[-0.0660, -0.0153, 0.0800],
[-0.0101, 0.0068, 0.0156],
[-0.0690, -0.0452, -0.0927],
[-0.0145, 0.0041, 0.0015],
[ 0.0421, 0.0451, 0.0373],
[ 0.0504, -0.0483, -0.0356],
[-0.0837, 0.0168, 0.0055]
]
[0.0119, 0.0103, 0.0046],
[-0.1062, -0.0504, 0.0165],
[0.0140, 0.0409, 0.0491],
[-0.0813, -0.0677, 0.0607],
[0.0656, 0.0851, 0.0808],
[0.0264, 0.0463, 0.0912],
[0.0295, 0.0326, 0.0590],
[-0.0244, -0.0270, 0.0025],
[0.0443, -0.0102, 0.0288],
[-0.0465, -0.0090, -0.0205],
[0.0359, 0.0236, 0.0082],
[-0.0776, 0.0854, 0.1048],
[0.0564, 0.0264, 0.0561],
[0.0006, 0.0594, 0.0418],
[-0.0319, -0.0542, -0.0637],
[-0.0268, 0.0024, 0.0260],
[0.0539, 0.0265, 0.0358],
[-0.0359, -0.0312, -0.0287],
[-0.0285, -0.1032, -0.1237],
[0.1041, 0.0537, 0.0622],
[-0.0086, -0.0374, -0.0051],
[0.0390, 0.0670, 0.2863],
[0.0069, 0.0144, 0.0082],
[0.0006, -0.0167, 0.0079],
[0.0313, -0.0574, -0.0232],
[-0.1454, -0.0902, -0.0481],
[0.0714, 0.0827, 0.0447],
[-0.0304, -0.0574, -0.0196],
[0.0401, 0.0384, 0.0204],
[-0.0758, -0.0297, -0.0014],
[0.0568, 0.1307, 0.1372],
[-0.0055, -0.0310, -0.0380],
[0.0239, -0.0305, 0.0325],
[-0.0663, -0.0673, -0.0140],
[-0.0416, -0.0047, -0.0023],
[0.0166, 0.0112, -0.0093],
[-0.0211, 0.0011, 0.0331],
[0.1833, 0.1466, 0.2250],
[-0.0368, 0.0370, 0.0295],
[-0.3441, -0.3543, -0.2008],
[-0.0479, -0.0489, -0.0420],
[-0.0660, -0.0153, 0.0800],
[-0.0101, 0.0068, 0.0156],
[-0.0690, -0.0452, -0.0927],
[-0.0145, 0.0041, 0.0015],
[0.0421, 0.0451, 0.0373],
[0.0504, -0.0483, -0.0356],
[-0.0837, 0.0168, 0.0055]
]
latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388]
def __init__(self):
self.scale_factor = 1.0
self.taesd_decoder_name = "lighttaew2_2"
self.latents_mean = torch.tensor([
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
]).view(1, self.latent_channels, 1, 1, 1)
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
-0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502,
-0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230,
-0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748,
0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667,
]).view(1, self.latent_channels, 1, 1, 1)
self.latents_std = torch.tensor([
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013,
0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978,
0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659,
0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093,
0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887,
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
class HunyuanImage21(LatentFormat):
latent_channels = 64
@ -554,105 +610,173 @@ class HunyuanImage21(LatentFormat):
latent_rgb_factors = [
[-0.0154, -0.0397, -0.0521],
[ 0.0005, 0.0093, 0.0006],
[0.0005, 0.0093, 0.0006],
[-0.0805, -0.0773, -0.0586],
[-0.0494, -0.0487, -0.0498],
[-0.0212, -0.0076, -0.0261],
[-0.0179, -0.0417, -0.0505],
[ 0.0158, 0.0310, 0.0239],
[ 0.0409, 0.0516, 0.0201],
[ 0.0350, 0.0553, 0.0036],
[0.0158, 0.0310, 0.0239],
[0.0409, 0.0516, 0.0201],
[0.0350, 0.0553, 0.0036],
[-0.0447, -0.0327, -0.0479],
[-0.0038, -0.0221, -0.0365],
[-0.0423, -0.0718, -0.0654],
[ 0.0039, 0.0368, 0.0104],
[ 0.0655, 0.0217, 0.0122],
[ 0.0490, 0.1638, 0.2053],
[ 0.0932, 0.0829, 0.0650],
[0.0039, 0.0368, 0.0104],
[0.0655, 0.0217, 0.0122],
[0.0490, 0.1638, 0.2053],
[0.0932, 0.0829, 0.0650],
[-0.0186, -0.0209, -0.0135],
[-0.0080, -0.0076, -0.0148],
[-0.0284, -0.0201, 0.0011],
[-0.0284, -0.0201, 0.0011],
[-0.0642, -0.0294, -0.0777],
[-0.0035, 0.0076, -0.0140],
[ 0.0519, 0.0731, 0.0887],
[-0.0102, 0.0095, 0.0704],
[ 0.0068, 0.0218, -0.0023],
[-0.0035, 0.0076, -0.0140],
[0.0519, 0.0731, 0.0887],
[-0.0102, 0.0095, 0.0704],
[0.0068, 0.0218, -0.0023],
[-0.0726, -0.0486, -0.0519],
[ 0.0260, 0.0295, 0.0263],
[ 0.0250, 0.0333, 0.0341],
[ 0.0168, -0.0120, -0.0174],
[ 0.0226, 0.1037, 0.0114],
[ 0.2577, 0.1906, 0.1604],
[0.0260, 0.0295, 0.0263],
[0.0250, 0.0333, 0.0341],
[0.0168, -0.0120, -0.0174],
[0.0226, 0.1037, 0.0114],
[0.2577, 0.1906, 0.1604],
[-0.0646, -0.0137, -0.0018],
[-0.0112, 0.0309, 0.0358],
[-0.0347, 0.0146, -0.0481],
[ 0.0234, 0.0179, 0.0201],
[ 0.0157, 0.0313, 0.0225],
[ 0.0423, 0.0675, 0.0524],
[-0.0031, 0.0027, -0.0255],
[ 0.0447, 0.0555, 0.0330],
[-0.0152, 0.0103, 0.0299],
[-0.0112, 0.0309, 0.0358],
[-0.0347, 0.0146, -0.0481],
[0.0234, 0.0179, 0.0201],
[0.0157, 0.0313, 0.0225],
[0.0423, 0.0675, 0.0524],
[-0.0031, 0.0027, -0.0255],
[0.0447, 0.0555, 0.0330],
[-0.0152, 0.0103, 0.0299],
[-0.0755, -0.0489, -0.0635],
[ 0.0853, 0.0788, 0.1017],
[0.0853, 0.0788, 0.1017],
[-0.0272, -0.0294, -0.0471],
[ 0.0440, 0.0400, -0.0137],
[ 0.0335, 0.0317, -0.0036],
[0.0440, 0.0400, -0.0137],
[0.0335, 0.0317, -0.0036],
[-0.0344, -0.0621, -0.0984],
[-0.0127, -0.0630, -0.0620],
[-0.0648, 0.0360, 0.0924],
[-0.0648, 0.0360, 0.0924],
[-0.0781, -0.0801, -0.0409],
[ 0.0363, 0.0613, 0.0499],
[ 0.0238, 0.0034, 0.0041],
[-0.0135, 0.0258, 0.0310],
[ 0.0614, 0.1086, 0.0589],
[ 0.0428, 0.0350, 0.0205],
[ 0.0153, 0.0173, -0.0018],
[0.0363, 0.0613, 0.0499],
[0.0238, 0.0034, 0.0041],
[-0.0135, 0.0258, 0.0310],
[0.0614, 0.1086, 0.0589],
[0.0428, 0.0350, 0.0205],
[0.0153, 0.0173, -0.0018],
[-0.0288, -0.0455, -0.0091],
[ 0.0344, 0.0109, -0.0157],
[0.0344, 0.0109, -0.0157],
[-0.0205, -0.0247, -0.0187],
[ 0.0487, 0.0126, 0.0064],
[-0.0220, -0.0013, 0.0074],
[0.0487, 0.0126, 0.0064],
[-0.0220, -0.0013, 0.0074],
[-0.0203, -0.0094, -0.0048],
[-0.0719, 0.0429, -0.0442],
[ 0.1042, 0.0497, 0.0356],
[-0.0719, 0.0429, -0.0442],
[0.1042, 0.0497, 0.0356],
[-0.0659, -0.0578, -0.0280],
[-0.0060, -0.0322, -0.0234]]
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
class HunyuanImage21Refiner(LatentFormat):
latent_channels = 64
latent_dimensions = 3
scale_factor = 1.03682
def process_in(self, latent):
out = latent * self.scale_factor
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
def process_out(self, latent):
z = latent / self.scale_factor
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
return z
class HunyuanVideo15(LatentFormat):
latent_rgb_factors = [
[0.0568, -0.0521, -0.0131],
[0.0014, 0.0735, 0.0326],
[0.0186, 0.0531, -0.0138],
[-0.0031, 0.0051, 0.0288],
[0.0110, 0.0556, 0.0432],
[-0.0041, -0.0023, -0.0485],
[0.0530, 0.0413, 0.0253],
[0.0283, 0.0251, 0.0339],
[0.0277, -0.0372, -0.0093],
[0.0393, 0.0944, 0.1131],
[0.0020, 0.0251, 0.0037],
[-0.0017, 0.0012, 0.0234],
[0.0468, 0.0436, 0.0203],
[0.0354, 0.0439, -0.0233],
[0.0090, 0.0123, 0.0346],
[0.0382, 0.0029, 0.0217],
[0.0261, -0.0300, 0.0030],
[-0.0088, -0.0220, -0.0283],
[-0.0272, -0.0121, -0.0363],
[-0.0664, -0.0622, 0.0144],
[0.0414, 0.0479, 0.0529],
[0.0355, 0.0612, -0.0247],
[0.0147, 0.0264, 0.0174],
[0.0438, 0.0038, 0.0542],
[0.0431, -0.0573, -0.0033],
[-0.0162, -0.0211, -0.0406],
[-0.0487, -0.0295, -0.0393],
[0.0005, -0.0109, 0.0253],
[0.0296, 0.0591, 0.0353],
[0.0119, 0.0181, -0.0306],
[-0.0085, -0.0362, 0.0229],
[0.0005, -0.0106, 0.0242]
]
latent_rgb_factors_bias = [0.0456, -0.0202, -0.0644]
latent_channels = 32
latent_dimensions = 3
scale_factor = 1.03682
taesd_decoder_name = "lighttaehy1_5"
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
class Hunyuan3Dv2_1(LatentFormat):
scale_factor = 1.0039506158752403
latent_channels = 64
latent_dimensions = 1
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 1.0188137142395404
class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
class ChromaRadiance(LatentFormat):
latent_channels = 3
def __init__(self):
self.latent_rgb_factors = [
# R G B
[ 1.0, 0.0, 0.0 ],
[ 0.0, 1.0, 0.0 ],
[ 0.0, 0.0, 1.0 ]
[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0]
]
def process_in(self, latent):

View File

@ -1,29 +1,33 @@
import torch
from torch import Tensor, nn
from ..flux.math import attention
from ..flux.layers import MLPEmbedder, RMSNorm, QKNorm, SelfAttention, ModulationOut
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
ModulationOut,
)
# TODO: remove this in a few months
SingleStreamBlock = None
DoubleStreamBlock = None
class ChromaModulationOut(ModulationOut):
@classmethod
def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
return cls(
shift=tensor[:, offset : offset + 1, :],
scale=tensor[:, offset + 1 : offset + 2, :],
gate=tensor[:, offset + 2 : offset + 3, :],
shift=tensor[:, offset: offset + 1, :],
scale=tensor[:, offset + 1: offset + 2, :],
gate=tensor[:, offset + 2: offset + 3, :],
)
class Approximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=5, dtype=None, device=None, operations=None):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property
@ -42,124 +46,6 @@ class Approximator(nn.Module):
return x
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
# calculate the txt bloks
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()

View File

@ -8,12 +8,15 @@ from einops import rearrange, repeat
from ..common_dit import pad_to_patch_size
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from ..flux.layers import EmbedND, timestep_embedding
from ..flux.layers import (
EmbedND,
timestep_embedding,
DoubleStreamBlock,
SingleStreamBlock,
)
from .layers import (
DoubleStreamBlock,
LastLayer,
SingleStreamBlock,
Approximator,
ChromaModulationOut,
)
@ -37,6 +40,8 @@ class ChromaParams:
out_dim: int
hidden_dim: int
n_layers: int
txt_ids_dims: list
vec_in_dim: int
class Chroma(nn.Module):
@ -84,6 +89,7 @@ class Chroma(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@ -92,7 +98,7 @@ class Chroma(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
@ -173,7 +179,10 @@ class Chroma(nn.Module):
pe = self.pe_embedder(ids)
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
if i not in self.skip_mmdit:
double_mod = (
self.get_modulations(mod_vectors, "double_img", idx=i),
@ -216,7 +225,10 @@ class Chroma(nn.Module):
img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit:
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
if ("single_block", i) in blocks_replace:

View File

@ -10,10 +10,10 @@ from torch import Tensor, nn
from einops import repeat
from ..common_dit import pad_to_patch_size
from ..flux.layers import EmbedND
from ..flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
from ..chroma.model import Chroma, ChromaParams
from ..chroma.layers import DoubleStreamBlock, SingleStreamBlock, Approximator
from ..chroma.layers import Approximator
from .layers import (
NerfEmbedder,
NerfGLUBlock,
@ -94,6 +94,7 @@ class ChromaRadiance(Chroma):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=False,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@ -106,6 +107,7 @@ class ChromaRadiance(Chroma):
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
modulation=False,
dtype=dtype, device=device, operations=operations,
)
for _ in range(params.depth_single_blocks)

View File

@ -47,15 +47,44 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
self.silu = nn.SiLU()
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class YakMLP(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
self.act_fn = nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
if yak_mlp:
return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
if mlp_silu_act:
return nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
return nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
@ -79,14 +108,14 @@ class QKNorm(torch.nn.Module):
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)
@dataclass
@ -97,11 +126,11 @@ class ModulationOut:
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple:
if vec.ndim == 2:
@ -128,80 +157,110 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor
class SiLUActivation(nn.Module):
def __init__(self):
super().__init__()
self.gate_fn = nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
x1, x2 = x.chunk(2, dim=-1)
return self.gate_fn(x1) * x2
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.modulation = modulation
if self.modulation:
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated)
img_qkv = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k, img_v = torch.unbind(img_qkv, dim=0)
del img_modulated
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del img_qkv
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_qkv = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k, txt_v = torch.unbind(txt_qkv, dim=0)
del txt_modulated
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
if self.flipped_img_txt:
q = torch.cat((img_q, txt_q), dim=2)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=2)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=2)
del img_v, txt_v
# run actual attention
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
q = torch.cat((txt_q, img_q), dim=2)
del txt_q, img_q
k = torch.cat((txt_k, img_k), dim=2)
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
# run actual attention
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
attn = attention(q, k, v,
pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the img blocks
# todo: do we have to re-investigate this += versus img = img + ... op?
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
del img_attn
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
# calculate the txt bloks
txt = txt + apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
txt = txt + apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
# calculate the txt blocks
# todo: do we have to re-investigate this += versus txt = txt + ... op?
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
del txt_attn
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@ -221,6 +280,10 @@ class SingleStreamBlock(nn.Module):
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
modulation=True,
mlp_silu_act=False,
bias=True,
yak_mlp=False,
dtype=None,
device=None,
operations=None
@ -232,31 +295,57 @@ class SingleStreamBlock(nn.Module):
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_hidden_dim_first = self.mlp_hidden_dim
self.yak_mlp = yak_mlp
if mlp_silu_act:
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
self.mlp_act = SiLUActivation()
else:
self.mlp_act = nn.GELU(approximate="tanh")
if self.yak_mlp:
self.mlp_hidden_dim_first *= 2
self.mlp_act = nn.SiLU()
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
if modulation:
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
else:
self.modulation = None
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
if self.modulation:
mod, _ = self.modulation(vec)
else:
mod = vec
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
qkv = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = torch.unbind(qkv, dim=0)
del qkv
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
else:
mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
x = x + apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
@ -264,11 +353,11 @@ class SingleStreamBlock(nn.Module):
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
if vec.ndim == 2:

View File

@ -7,15 +7,8 @@ from ... import model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape
if pe is not None:
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x

View File

@ -13,6 +13,8 @@ from .layers import (
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
Modulation,
RMSNorm
)
from .. import common_dit
@ -33,6 +35,14 @@ class FluxParams:
patch_size: int
qkv_bias: bool
guidance_embed: bool
txt_ids_dims: list
global_modulation: bool = False
mlp_silu_act: bool = False
ops_bias: bool = True
default_ref_method: str = "offset"
ref_index_scale: float = 1.0
yak_mlp: bool = False
txt_norm: bool = False
class Flux(nn.Module):
@ -60,13 +70,22 @@ class Flux(nn.Module):
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
if params.vec_in_dim is not None:
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.vector_in = None
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
else:
self.txt_norm = None
self.double_blocks = nn.ModuleList(
[
@ -75,6 +94,10 @@ class Flux(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
modulation=params.global_modulation is False,
mlp_silu_act=params.mlp_silu_act,
proj_bias=params.ops_bias,
yak_mlp=params.yak_mlp,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@ -83,13 +106,30 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
if final_layer:
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
if params.global_modulation:
self.double_stream_modulation_img = Modulation(
self.hidden_size,
double=True,
bias=False,
dtype=dtype, device=device, operations=operations
)
self.double_stream_modulation_txt = Modulation(
self.hidden_size,
double=True,
bias=False,
dtype=dtype, device=device, operations=operations
)
self.single_stream_modulation = Modulation(
self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
)
def forward_orig(
self,
@ -122,9 +162,19 @@ class Flux(nn.Module):
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.vector_in is not None:
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.txt_norm is not None:
txt = self.txt_norm(txt)
txt = self.txt_in(txt)
vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
@ -140,7 +190,10 @@ class Flux(nn.Module):
pe = None
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap_1(args):
out = {}
@ -181,7 +234,13 @@ class Flux(nn.Module):
img = torch.cat((txt, img), 1)
if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
def block_wrap_2(args):
out = {}
@ -211,10 +270,10 @@ class Flux(nn.Module):
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
return img
def process_img(self, x, index=0, h_offset=0, w_offset=0):
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@ -226,10 +285,22 @@ class Flux(nn.Module):
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
steps_h = h_len
steps_w = w_len
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
index += rope_options.get("shift_t", 0.0)
h_offset += rope_options.get("shift_y", 0.0)
w_offset += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options=None, **kwargs):
@ -249,16 +320,16 @@ class Flux(nn.Module):
h_len = ((h_orig + (patch_size // 2)) // patch_size)
w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x)
img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1]
if ref_latents is not None:
h = 0
w = 0
index = 0
ref_latents_method = kwargs.get("ref_latents_method", "offset")
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
for ref in ref_latents:
if ref_latents_method == "index":
index += 1
index += self.params.ref_index_scale
h_offset = 0
w_offset = 0
elif ref_latents_method == "uxo":
@ -282,7 +353,12 @@ class Flux(nn.Module):
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
if len(self.params.txt_ids_dims) > 0:
for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h_orig, :w_orig]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:, :, :h_orig, :w_orig]

View File

@ -33,6 +33,8 @@ class HunyuanVideoParams:
guidance_embed: bool
byt5: bool
meanflow: bool
use_cond_type_embedding: bool
vision_in_dim: int
class SelfAttentionRef(nn.Module):
@ -153,7 +155,10 @@ class TokenRefiner(nn.Module):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
c = x.sum(dim=1) / x.shape[1]
if x.dtype == torch.float16:
c = x.float().sum(dim=1) / x.shape[1]
else:
c = x.sum(dim=1) / x.shape[1]
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
@ -193,11 +198,15 @@ class HunyuanVideo(nn.Module):
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
params = HunyuanVideoParams(**kwargs)
self.params = params
self.patch_size = params.patch_size
self.in_channels = params.in_channels
self.out_channels = params.out_channels
self.use_cond_type_embedding = params.use_cond_type_embedding
self.vision_in_dim = params.vision_in_dim
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@ -263,6 +272,18 @@ class HunyuanVideo(nn.Module):
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
# HunyuanVideo 1.5 specific modules
if self.vision_in_dim is not None:
from ..wan.model import MLPProj
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
else:
self.vision_in = None
if self.use_cond_type_embedding:
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
else:
self.cond_type_embedding = None
def forward_orig(
self,
img: Tensor,
@ -273,7 +294,7 @@ class HunyuanVideo(nn.Module):
timesteps: Tensor,
y: Tensor = None,
txt_byt5=None,
guidance: Tensor = None,
clip_fea=None,guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
disable_time_r=False,
@ -330,12 +351,31 @@ class HunyuanVideo(nn.Module):
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
if self.cond_type_embedding is not None:
self.cond_type_embedding.to(txt.device)
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
txt = txt + cond_emb.to(txt.dtype)
if self.byt5_in is not None and txt_byt5 is not None:
txt_byt5 = self.byt5_in(txt_byt5)
if self.cond_type_embedding is not None:
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
else:
txt = torch.cat((txt, txt_byt5), dim=1)
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt = torch.cat((txt, txt_byt5), dim=1)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
if clip_fea is not None:
txt_vision_states = self.vision_in(clip_fea)
if self.cond_type_embedding is not None:
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
txt_vision_states = txt_vision_states + cond_emb
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@ -348,7 +388,10 @@ class HunyuanVideo(nn.Module):
attn_mask = None
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.double_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap_2(args):
out = {}
@ -370,7 +413,10 @@ class HunyuanVideo(nn.Module):
img = torch.cat((img, txt), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -429,16 +475,16 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options=None, **kwargs):
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options=None, **kwargs):
if transformer_options is None:
transformer_options = {}
return WrapperExecutor.new_class_executor(
self._forward,
self,
get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options=None, **kwargs):
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options=None, **kwargs):
if transformer_options is None:
transformer_options = {}
bs = x.shape[0]
@ -448,5 +494,5 @@ class HunyuanVideo(nn.Module):
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
return out

View File

@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
import model_management, model_patcher
class SRResidualCausalBlock3D(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.block = nn.Sequential(
VideoConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
VideoConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
VideoConv3d(channels, channels, kernel_size=3),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.block(x)
class SRModel3DV2(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int = 64,
num_blocks: int = 6,
global_residual: bool = False,
):
super().__init__()
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
self.global_residual = bool(global_residual)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
y = self.in_conv(x)
for blk in self.blocks:
y = blk(y)
y = self.out_conv(y)
if self.global_residual and (y.shape == residual.shape):
y = y + residual
return y
class Upsampler(nn.Module):
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: tuple[int, ...],
num_res_blocks: int = 2,
):
super().__init__()
self.num_res_blocks = num_res_blocks
self.block_out_channels = block_out_channels
self.z_channels = z_channels
ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
self.up = nn.ModuleList()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_shortcut=False,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks + 1)])
ch = tgt
self.up.append(stage)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
def forward(self, z):
"""
Args:
z: (B, C, T, H, W)
target_shape: (H, W)
"""
# z to block_in
repeats = self.block_out_channels[0] // (self.z_channels)
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
# upsampling
for stage in self.up:
for blk in stage.block:
x = blk(x)
out = self.conv_out(F.silu(self.norm_out(x)))
return out
UPSAMPLERS = {
"720p": SRModel3DV2,
"1080p": Upsampler,
}
class HunyuanVideo15SRModel():
def __init__(self, model_type, config):
self.load_device = model_management.vae_device()
offload_device = model_management.vae_offload_device()
self.dtype = model_management.vae_dtype(self.load_device)
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
def get_sd(self):
return self.model.state_dict()
def resample_latent(self, latent):
model_management.load_model_gpu(self.patcher)
return self.model(latent.to(self.load_device))

View File

@ -1,8 +1,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
from ..models.autoencoder import DiagonalGaussianRegularizer
from ..modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, \
torch_cat_if_needed
from ...model_management import cast_to
from ...ops import disable_weight_init as ops
@ -14,11 +17,11 @@ class RMS_norm(nn.Module):
self.gamma = nn.Parameter(torch.empty(shape))
def forward(self, x):
return F.normalize(x, dim=1) * self.scale * self.gamma
return F.normalize(x, dim=1) * self.scale * cast_to(self.gamma, dtype=x.dtype, device=x.device)
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
def __init__(self, ic, oc, tds, refiner_vae, op):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
@ -28,11 +31,11 @@ class DnSmpl(nn.Module):
self.tds = tds
self.gs = fct * ic // oc
def forward(self, x):
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
r1 = 2 if self.tds else 1
h = self.conv(x)
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
if self.tds and self.refiner_vae:
if self.tds and self.refiner_vae and conv_carry_in is None:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@ -40,14 +43,7 @@ class DnSmpl(nn.Module):
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
hf = torch.cat([hf, hf], dim=1)
hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nf = frms // r1
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
h = torch.cat([hf, hn], dim=2)
h = h[:, :, 1:, :, :]
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
@ -55,38 +51,36 @@ class DnSmpl(nn.Module):
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
B, C, T, H, W = xf.shape
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
xn = x[:, :, 1:, :, :]
b, ci, frms, ht, wd = xn.shape
nf = frms // r1
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = xn.shape
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
x = x[:, :, 1:, :, :]
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
if h.shape[2] == 0:
return hf + xf
b, ci, frms, ht, wd = x.shape
nf = frms // r1
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = sc.shape
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
return h + sc
b, ci, frms, ht, wd = x.shape
nf = frms // r1
x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = x.shape
x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
if self.tds and self.refiner_vae and conv_carry_in is None:
h = torch.cat([hf, h], dim=2)
x = torch.cat([xf, x], dim=2)
return h + x
class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
def __init__(self, ic, oc, tus, refiner_vae, op):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@ -95,11 +89,11 @@ class UpSmpl(nn.Module):
self.tus = tus
self.rp = fct * oc // ic
def forward(self, x):
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
r1 = 2 if self.tus else 1
h = self.conv(x)
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
if self.tus and self.refiner_vae:
if self.tus and self.refiner_vae and conv_carry_in is None:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
@ -108,14 +102,7 @@ class UpSmpl(nn.Module):
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
hf = hf[:, : hf.shape[1] // 2]
hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nc = c // (r1 * 2 * 2)
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
h = torch.cat([hf, hn], dim=2)
h = h[:, :, 1:, :, :]
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
@ -126,29 +113,26 @@ class UpSmpl(nn.Module):
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
xn = x[:, :, 1:, :, :]
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = xn.shape
nc = c // (r1 * 2 * 2)
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2)
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
x = x[:, :, 1:, :, :]
sc = x.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = sc.shape
nc = c // (r1 * 2 * 2)
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2)
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
return h + sc
x = x.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = x.shape
nc = c // (r1 * 2 * 2)
x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
if self.tus and self.refiner_vae and conv_carry_in is None:
h = torch.cat([hf, h], dim=2)
x = torch.cat([xf, x], dim=2)
return h + x
class Encoder(nn.Module):
@ -162,7 +146,7 @@ class Encoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
conv_op = CarriedConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@ -190,9 +174,9 @@ class Encoder(nn.Module):
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@ -203,31 +187,48 @@ class Encoder(nn.Module):
if not self.refiner_vae and x.shape[2] == 1:
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
x = self.conv_in(x)
if self.refiner_vae:
xl = [x[:, :, :1, :, :]]
if x.shape[2] > self.ffactor_temporal:
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
x = xl
else:
x = [x]
out = []
for stage in self.down:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'downsample'):
x = stage.downsample(x)
conv_carry_in = None
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
x1 = [x1]
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for stage in self.down:
for blk in stage.block:
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'downsample'):
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
out = torch_cat_if_needed(out, dim=2)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
del out
b, c, t, h, w = x.shape
grp = c // (self.z_channels << 1)
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
if self.refiner_vae:
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
@ -242,7 +243,7 @@ class Decoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = VideoConv3d
conv_op = CarriedConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@ -252,9 +253,9 @@ class Decoder(nn.Module):
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@ -278,24 +279,34 @@ class Decoder(nn.Module):
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
def forward(self, z):
if self.refiner_vae:
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
for stage in self.up:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
if self.refiner_vae:
x = torch.split(x, 2, dim=2)
else:
x = [x]
out = []
out = self.conv_out(F.silu(self.norm_out(x)))
conv_carry_in = None
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
for stage in self.up:
for blk in stage.block:
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'upsample'):
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
x1 = [F.silu(self.norm_out(x1))]
x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
del x
out = torch_cat_if_needed(out, dim=2)
if not self.refiner_vae:
if z.shape[-3] == 1:

View File

@ -0,0 +1,413 @@
import torch
from torch import nn
import math
import comfy.ldm.common_dit
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.layers import EmbedND
def attention(q, k, v, heads, transformer_options={}):
return optimized_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
heads=heads,
skip_reshape=True,
transformer_options=transformer_options
)
def apply_scale_shift_norm(norm, x, scale, shift):
return torch.addcmul(shift, norm(x), scale + 1.0)
def apply_gate_sum(x, out, gate):
return torch.addcmul(x, gate, out)
def get_shift_scale_gate(params):
shift, scale, gate = torch.chunk(params, 3, dim=-1)
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
def get_freqs(dim, max_period=10000.0):
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
class TimeEmbeddings(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
super().__init__()
assert model_dim % 2 == 0
self.model_dim = model_dim
self.max_period = max_period
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.activation = nn.SiLU()
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, timestep, dtype):
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
class TextEmbeddings(nn.Module):
def __init__(self, text_dim, model_dim, operation_settings=None):
super().__init__()
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, text_embed):
text_embed = self.in_layer(text_embed)
return self.norm(text_embed).type_as(text_embed)
class VisualEmbeddings(nn.Module):
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
super().__init__()
self.patch_size = patch_size
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
x = x.movedim(1, -1) # B C T H W -> B T H W C
B, T, H, W, dim = x.shape
pt, ph, pw = self.patch_size
x = x.view(
B,
T // pt, pt,
H // ph, ph,
W // pw, pw,
dim,
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
return self.in_layer(x)
class Modulation(nn.Module):
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
super().__init__()
self.activation = nn.SiLU()
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
return self.out_layer(self.activation(x))
class SelfAttention(nn.Module):
def __init__(self, num_channels, head_dim, operation_settings=None):
super().__init__()
assert num_channels % head_dim == 0
self.num_heads = num_channels // head_dim
self.head_dim = head_dim
operations = operation_settings.get("operations")
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.num_chunks = 2
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
return apply_rope1(norm_fn(result), freqs)
def _forward(self, x, freqs, transformer_options={}):
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
def _forward_chunked(self, x, freqs, transformer_options={}):
def process_chunks(proj_fn, norm_fn):
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
chunks = []
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
return torch.cat(chunks, dim=1)
q = process_chunks(self.to_query, self.query_norm)
k = process_chunks(self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
def forward(self, x, freqs, transformer_options={}):
if x.shape[1] > 8192:
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
else:
return self._forward(x, freqs, transformer_options=transformer_options)
class CrossAttention(SelfAttention):
def get_qkv(self, x, context):
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
return q, k, v
def forward(self, x, context, transformer_options={}):
q, k, v = self.get_qkv(x, context)
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
class FeedForward(nn.Module):
def __init__(self, dim, ff_dim, operation_settings=None):
super().__init__()
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.activation = nn.GELU()
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.num_chunks = 4
def _forward(self, x):
return self.out_layer(self.activation(self.in_layer(x)))
def _forward_chunked(self, x):
chunks = torch.chunk(x, self.num_chunks, dim=1)
output_chunks = []
for chunk in chunks:
output_chunks.append(self._forward(chunk))
return torch.cat(output_chunks, dim=1)
def forward(self, x):
if x.shape[1] > 8192:
return self._forward_chunked(x)
else:
return self._forward(x)
class OutLayer(nn.Module):
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
super().__init__()
self.patch_size = patch_size
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, visual_embed, time_embed):
B, T, H, W, _ = visual_embed.shape
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
scale = scale[:, None, None, None, :]
shift = shift[:, None, None, None, :]
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
x = self.out_layer(visual_embed)
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
x = x.view(
B, T, H, W,
out_dim,
self.patch_size[0], self.patch_size[1], self.patch_size[2]
)
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
class TransformerEncoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
super().__init__()
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
def forward(self, x, time_embed, freqs, transformer_options={}):
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
shift, scale, gate = get_shift_scale_gate(self_attn_params)
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
out = self.self_attention(out, freqs, transformer_options=transformer_options)
x = apply_gate_sum(x, out, gate)
shift, scale, gate = get_shift_scale_gate(ff_params)
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
out = self.feed_forward(out)
x = apply_gate_sum(x, out, gate)
return x
class TransformerDecoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
super().__init__()
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
# self attention
shift, scale, gate = get_shift_scale_gate(self_attn_params)
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# cross attention
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# feed forward
shift, scale, gate = get_shift_scale_gate(ff_params)
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
visual_out = self.feed_forward(visual_out)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
return visual_embed
class Kandinsky5(nn.Module):
def __init__(
self,
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
dtype=None, device=None, operations=None, **kwargs
):
super().__init__()
head_dim = sum(axes_dims)
self.rope_scale_factor = rope_scale_factor
self.in_visual_dim = in_visual_dim
self.model_dim = model_dim
self.patch_size = patch_size
self.visual_embed_dim = visual_embed_dim
self.dtype = dtype
self.device = device
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
self.text_transformer_blocks = nn.ModuleList(
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
)
self.visual_transformer_blocks = nn.ModuleList(
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
)
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
steps = seq_len if steps is None else steps
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
return freqs
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if steps_t is None:
steps_t = t_len
if steps_h is None:
steps_h = h_len
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
else:
rope_scale_factor = self.rope_scale_factor
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
if h * w >= 14080:
rope_scale_factor = (1.0, 3.16, 3.16)
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
return freqs
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
context = self.text_embeddings(context)
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
for block in self.text_transformer_blocks:
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
visual_embed = self.visual_embeddings(x)
visual_shape = visual_embed.shape[:-1]
visual_embed = visual_embed.flatten(1, -2)
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.visual_transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
else:
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
visual_embed = visual_embed.reshape(*visual_shape, -1)
return self.out_layer(visual_embed, time_embed)
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
original_dims = x.ndim
if original_dims == 4:
x = x.unsqueeze(2)
bs, c, t_len, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if time_dim_replace is not None:
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
if original_dims == 4:
out = out.squeeze(2)
return out
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)

View File

@ -1,12 +1,12 @@
import torch
from torch import nn
from ..common_dit import rms_norm
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
import torch
from torch import nn
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from ..common_dit import rms_norm
from ..flux.math import apply_rope1
from ..modules.attention import optimized_attention, optimized_attention_masked
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
@ -181,10 +181,11 @@ class AdaLayerNormSingle(nn.Module):
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
batch_size: Optional[int] = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]: #, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
# todo: whats going on with the signature?
return self.linear(self.silu(embedded_timestep)), embedded_timestep
@ -240,20 +241,6 @@ class FeedForward(nn.Module):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis): # TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
@ -285,8 +272,8 @@ class CrossAttention(nn.Module):
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
if mask is None:
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@ -312,12 +299,17 @@ class BasicTransformerBlock(nn.Module):
transformer_options = {}
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
attn1_input = rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
y = rms_norm(x)
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
x.addcmul_(self.ff(y), gate_mlp)
return x
@ -336,41 +328,35 @@ def get_fractional_positions(indices_grid, max_pos):
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=None):
if max_pos is None:
max_pos = [20, 2048, 2048]
dtype = torch.float32 # self.dtype
dtype = torch.float32
device = indices_grid.device
# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
start = 1
end = theta
device = fractional_positions.device
# Compute frequencies and apply cos/sin
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
indices = indices * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
# Pad if dim is not divisible by 6
if dim % 6 != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
padding_size = dim % 6
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
freqs_cis = torch.stack([
torch.stack([cos_vals, -sin_vals], dim=-1),
torch.stack([sin_vals, cos_vals], dim=-1)
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
return freqs_cis
class LTXVModel(torch.nn.Module):
@ -515,7 +501,7 @@ class LTXVModel(torch.nn.Module):
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = torch.addcmul(x, x, scale).add_(shift)
x = self.proj_out(x)
x = self.patchifier.unpatchify(

View File

@ -0,0 +1,113 @@
import torch
from torch import nn
from .model import JointTransformerBlock
class ZImageControlTransformerBlock(JointTransformerBlock):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
qk_norm: bool,
modulation=True,
block_id=0,
operation_settings=None,
):
super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
self.block_id = block_id
if block_id == 0:
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
return c_skip, c
class ZImage_Control(torch.nn.Module):
def __init__(
self,
dim: int = 3840,
n_heads: int = 30,
n_kv_heads: int = 30,
multiple_of: int = 256,
ffn_dim_multiplier: float = (8.0 / 3.0),
norm_eps: float = 1e-5,
qk_norm: bool = True,
dtype=None,
device=None,
operations=None,
**kwargs
):
super().__init__()
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.additional_in_dim = 0
self.control_in_dim = 16
n_refiner_layers = 2
self.n_control_layers = 6
self.control_layers = nn.ModuleList(
[
ZImageControlTransformerBlock(
i,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
block_id=i,
operation_settings=operation_settings,
)
for i in range(self.n_control_layers)
]
)
all_x_embedder = {}
patch_size = 2
f_patch_size = 1
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
self.control_noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=True,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
patch_size = 2
f_patch_size = 1
pH = pW = patch_size
B, C, H, W = control_context.shape
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
x_attn_mask = None
for layer in self.control_noise_refiner:
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
return control_context
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)

View File

@ -12,26 +12,34 @@ from ..modules.diffusionmodules.mmdit import TimestepEmbedder
from ..modules.attention import optimized_attention_masked
from ..flux.layers import EmbedND
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from ..flux.math import apply_rope
def modulate(x, scale):
return x * (1 + scale.unsqueeze(1))
#############################################################################
# Core NextDiT Model #
#############################################################################
def clamp_fp16(x):
if x.dtype == torch.float16:
return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class JointAttention(nn.Module):
"""Multi-head attention module."""
def __init__(
self,
dim: int,
n_heads: int,
n_kv_heads: Optional[int],
qk_norm: bool,
operation_settings={},
self,
dim: int,
n_heads: int,
n_kv_heads: Optional[int],
qk_norm: bool,
out_bias: bool = False,
operation_settings=None,
):
"""
Initialize the Attention module.
@ -43,6 +51,8 @@ class JointAttention(nn.Module):
"""
super().__init__()
if operation_settings is None:
operation_settings = {}
self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
self.n_local_heads = n_heads
self.n_local_kv_heads = self.n_kv_heads
@ -59,7 +69,7 @@ class JointAttention(nn.Module):
self.out = operation_settings.get("operations").Linear(
n_heads * self.head_dim,
dim,
bias=False,
bias=out_bias,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
)
@ -70,41 +80,12 @@ class JointAttention(nn.Module):
else:
self.q_norm = self.k_norm = nn.Identity()
@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x_in.shape)
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
transformer_options={},
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
transformer_options=None,
) -> torch.Tensor:
"""
@ -116,6 +97,8 @@ class JointAttention(nn.Module):
Returns:
"""
if transformer_options is None:
transformer_options = {}
bsz, seqlen, _ = x.shape
xq, xk, xv = torch.split(
@ -134,8 +117,7 @@ class JointAttention(nn.Module):
xq = self.q_norm(xq)
xk = self.k_norm(xk)
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
xq, xk = apply_rope(xq, xk, freqs_cis)
n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1:
@ -148,12 +130,12 @@ class JointAttention(nn.Module):
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
operation_settings={},
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
operation_settings=None,
):
"""
Initialize the FeedForward module.
@ -169,6 +151,8 @@ class FeedForward(nn.Module):
"""
super().__init__()
# custom dim factor multiplier
if operation_settings is None:
operation_settings = {}
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
@ -197,7 +181,7 @@ class FeedForward(nn.Module):
# @torch.compile
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
return clamp_fp16(F.silu(x1) * x3)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@ -205,17 +189,19 @@ class FeedForward(nn.Module):
class JointTransformerBlock(nn.Module):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
qk_norm: bool,
modulation=True,
operation_settings={},
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
qk_norm: bool,
modulation=True,
z_image_modulation=False,
attn_out_bias=False,
operation_settings=None,
) -> None:
"""
Initialize a TransformerBlock.
@ -233,12 +219,14 @@ class JointTransformerBlock(nn.Module):
"""
super().__init__()
if operation_settings is None:
operation_settings = {}
self.dim = dim
self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
self.feed_forward = FeedForward(
dim=dim,
hidden_dim=4 * dim,
hidden_dim=dim,
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
operation_settings=operation_settings,
@ -252,24 +240,35 @@ class JointTransformerBlock(nn.Module):
self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
if z_image_modulation:
self.adaLN_modulation = nn.Sequential(
operation_settings.get("operations").Linear(
min(dim, 256),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
else:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
def forward(
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
transformer_options={},
self,
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
transformer_options=None,
):
"""
Perform a forward pass through the TransformerBlock.
@ -283,32 +282,34 @@ class JointTransformerBlock(nn.Module):
feedforward layers.
"""
if transformer_options is None:
transformer_options = {}
if self.modulation:
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
self.attention(
clamp_fp16(self.attention(
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
))
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
self.feed_forward(
clamp_fp16(self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
)
))
)
else:
assert adaln_input is None
x = x + self.attention_norm2(
self.attention(
clamp_fp16(self.attention(
self.attention_norm1(x),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
))
)
x = x + self.ffn_norm2(
self.feed_forward(
@ -323,8 +324,10 @@ class FinalLayer(nn.Module):
The final layer of NextDiT.
"""
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings=None):
super().__init__()
if operation_settings is None:
operation_settings = {}
self.norm_final = operation_settings.get("operations").LayerNorm(
hidden_size,
elementwise_affine=False,
@ -340,10 +343,15 @@ class FinalLayer(nn.Module):
dtype=operation_settings.get("dtype"),
)
if z_image_modulation:
min_mod = 256
else:
min_mod = 1024
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(hidden_size, 1024),
min(hidden_size, min_mod),
hidden_size,
bias=True,
device=operation_settings.get("device"),
@ -364,25 +372,30 @@ class NextDiT(nn.Module):
"""
def __init__(
self,
patch_size: int = 2,
in_channels: int = 4,
dim: int = 4096,
n_layers: int = 32,
n_refiner_layers: int = 2,
n_heads: int = 32,
n_kv_heads: Optional[int] = None,
multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None,
norm_eps: float = 1e-5,
qk_norm: bool = False,
cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
image_model=None,
device=None,
dtype=None,
operations=None,
self,
patch_size: int = 2,
in_channels: int = 4,
dim: int = 4096,
n_layers: int = 32,
n_refiner_layers: int = 2,
n_heads: int = 32,
n_kv_heads: Optional[int] = None,
multiple_of: int = 256,
ffn_dim_multiplier: float = 4.0,
norm_eps: float = 1e-5,
qk_norm: bool = False,
cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512),
rope_theta=10000.0,
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
clip_text_dim=None,
image_model=None,
device=None,
dtype=None,
operations=None,
) -> None:
super().__init__()
self.dtype = dtype
@ -390,6 +403,8 @@ class NextDiT(nn.Module):
self.in_channels = in_channels
self.out_channels = in_channels
self.patch_size = patch_size
self.time_scale = time_scale
self.pad_tokens_multiple = pad_tokens_multiple
self.x_embedder = operation_settings.get("operations").Linear(
in_features=patch_size * patch_size * in_channels,
@ -411,6 +426,7 @@ class NextDiT(nn.Module):
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=z_image_modulation,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
@ -434,7 +450,7 @@ class NextDiT(nn.Module):
]
)
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
self.cap_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
@ -446,6 +462,31 @@ class NextDiT(nn.Module):
),
)
self.clip_text_pooled_proj = None
if clip_text_dim is not None:
self.clip_text_dim = clip_text_dim
self.clip_text_pooled_proj = nn.Sequential(
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
clip_text_dim,
clip_text_dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.time_text_embed = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024) + clip_text_dim,
min(dim, 1024),
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.layers = nn.ModuleList(
[
JointTransformerBlock(
@ -457,23 +498,29 @@ class NextDiT(nn.Module):
ffn_dim_multiplier,
norm_eps,
qk_norm,
z_image_modulation=z_image_modulation,
attn_out_bias=False,
operation_settings=operation_settings,
)
for layer_id in range(n_layers)
]
)
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
if self.pad_tokens_multiple is not None:
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
self.dim = dim
self.n_heads = n_heads
def unpatchify(
self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
) -> List[torch.Tensor]:
"""
x: (N, T, patch_size**2 * C)
@ -498,101 +545,61 @@ class NextDiT(nn.Module):
return imgs
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options=None
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
if transformer_options is None:
transformer_options = {}
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
dtype = x[0].dtype
if cap_mask is not None:
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
else:
l_effective_cap_len = [num_tokens] * bsz
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
if cap_mask is not None and not torch.is_floating_point(cap_mask):
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
img_sizes = [(img.size(1), img.size(2)) for img in x]
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
B, C, H, W = x.shape
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
max_seq_len = max(
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
)
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
H_tokens, W_tokens = H // pH, W // pW
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
# build freqs_cis for cap and image individually
cap_freqs_cis_shape = list(freqs_cis.shape)
# cap_freqs_cis_shape[1] = max_cap_len
cap_freqs_cis_shape[1] = cap_feats.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
# refine image
flat_x = []
for i in range(bsz):
img = x[i]
C, H, W = img.size()
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_x.append(img)
x = flat_x
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
for i in range(bsz):
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
padded_img_mask = None
for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
else:
mask = None
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
@ -603,7 +610,9 @@ class NextDiT(nn.Module):
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options=None, **kwargs):
if transformer_options is None:
transformer_options = {}
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@ -615,21 +624,36 @@ class NextDiT(nn.Module):
y: (N,) tensor of text tokens/features
"""
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
transformer_options = kwargs.get("transformer_options", {})
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(x.device)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
for i, layer in enumerate(self.layers):
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
return -x
img = self.final_layer(img, adaln_input)
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
return -img

View File

@ -1,10 +1,12 @@
import logging
import math
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union, Callable
from typing import Any, Dict, Tuple, Union, Callable, Optional
import torch
from einops import rearrange
from ...model_management import cast_to
from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.ema import LitEma
from ..util import instantiate_from_config, get_obj_from_str
@ -12,6 +14,7 @@ from ... import ops
logger = logging.getLogger(__name__)
class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = False):
super().__init__()
@ -20,7 +23,7 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
def get_trainable_parameters(self) -> Any:
yield from ()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Optional[dict]]:
posterior = DiagonalGaussianDistribution(z)
if self.sample:
z = posterior.sample()
@ -28,13 +31,15 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
z = posterior.mode()
return z, None
class EmptyRegularizer(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Optional[dict]]:
return z, None
class AbstractAutoencoder(torch.nn.Module):
"""
This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators,
@ -181,8 +186,27 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if ddconfig.get("batch_norm_latent", False):
self.bn_eps = 1e-4
self.bn_momentum = 0.1
self.ps = [2, 2]
self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"],
eps=self.bn_eps,
momentum=self.bn_momentum,
affine=False,
track_running_stats=True,
)
self.bn.eval()
else:
self.bn = None
def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params()
return params
def encode(
self, x: torch.Tensor, return_reg_log: bool = False
self, x: torch.Tensor, return_reg_log: bool = False,
unregularized: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
if self.max_batch_size is None:
z = self.encoder(x)
@ -199,11 +223,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
z = torch.cat(z, 0)
z, reg_log = self.regularization(z)
if self.bn is not None:
z = rearrange(z,
"... c (i pi) (j pj) -> ... (c pi pj) i j",
pi=self.ps[0],
pj=self.ps[1],
)
z = torch.nn.functional.batch_norm(z,
cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device),
cast_to(self.bn.running_var, dtype=z.dtype, device=z.device),
momentum=self.bn_momentum,
eps=self.bn_eps)
if return_reg_log:
return z, reg_log
return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.bn is not None:
s = torch.sqrt(cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps)
m = cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device)
z = z * s + m
z = rearrange(
z,
"... (c pi pj) i j -> ... c (i pi) (j pj)",
pi=self.ps[0],
pj=self.ps[1],
)
if self.max_batch_size is None:
dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs)

View File

@ -557,6 +557,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@ -581,6 +582,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logger.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
exception_fallback = True
if exception_fallback:
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),

View File

@ -213,12 +213,14 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
super().__init__()
if output_size is None:
output_size = hidden_size
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size

View File

@ -18,6 +18,13 @@ if model_management.xformers_enabled_vae():
import xformers.ops # pylint: disable=import-error
def torch_cat_if_needed(xl, dim):
if len(xl) > 1:
return torch.cat(xl, dim)
else:
return xl[0]
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
@ -41,13 +48,43 @@ def get_timestep_embedding(timesteps, embedding_dim):
def nonlinearity(x):
# swish
return torch.nn.functional.silu(x)
return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class CarriedConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
super().__init__()
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
return self.conv(x)
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
x = xl[0]
xl.clear()
if isinstance(op, CarriedConv3d):
if conv_carry_in is None:
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode='replicate')
else:
carry_len = conv_carry_in[0].shape[2]
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode='replicate')
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
if conv_carry_out is not None:
to_push = x[:, :, -2:, :, :].clone()
conv_carry_out.append(to_push)
out = op(x)
return out
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
@ -96,29 +133,24 @@ class Upsample(nn.Module):
stride=1,
padding=1)
def forward(self, x):
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
scale_factor = self.scale_factor
if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
t = x.shape[2]
if t > 1:
a, b = x.split((1, t - 1), dim=2)
del x
b = interpolate_up(b, scale_factor)
else:
a = x
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
if t > 1:
x = torch.cat((a, b), dim=2)
else:
x = a
results = []
if conv_carry_in is None:
first = x[:, :, :1, :, :]
results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
x = x[:, :, 1:, :, :]
if x.shape[2] > 0:
results.append(interpolate_up(x, scale_factor))
x = torch_cat_if_needed(results, dim=2)
else:
x = interpolate_up(x, scale_factor)
if self.with_conv:
x = self.conv(x)
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
return x
@ -134,17 +166,20 @@ class Downsample(nn.Module):
stride=stride,
padding=0)
def forward(self, x):
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
if self.with_conv:
if x.ndim == 4:
if isinstance(self.conv, CarriedConv3d):
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
elif x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
x = self.conv(x)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
x = self.conv(x)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
@ -190,23 +225,23 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
def forward(self, x, temb=None):
def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
h = x
h = self.norm1(h)
h = self.swish(h)
h = self.conv1(h)
h = [self.swish(h)]
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if temb is not None:
h = h + self.temb_proj(self.swish(temb))[:, :, None, None]
h = self.norm2(h)
h = self.swish(h)
h = self.dropout(h)
h = self.conv2(h)
h = [self.dropout(h)]
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
else:
x = self.nin_shortcut(x)
@ -290,6 +325,7 @@ def pytorch_attention(q, k, v):
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
oom_fallback = False
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@ -300,6 +336,8 @@ def pytorch_attention(q, k, v):
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logger.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True
if oom_fallback:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
@ -529,9 +567,14 @@ class Encoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.carried = False
if conv3d:
conv_op = VideoConv3d
if not attn_resolutions:
conv_op = CarriedConv3d
self.carried = True
else:
conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@ -544,6 +587,7 @@ class Encoder(nn.Module):
stride=1,
padding=1)
self.time_compress = 1
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
@ -570,10 +614,15 @@ class Encoder(nn.Module):
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
else:
self.time_compress *= 2
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
if time_compress is not None:
self.time_compress = time_compress
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
@ -599,15 +648,42 @@ class Encoder(nn.Module):
def forward(self, x):
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
if self.carried:
xl = [x[:, :, :1, :, :]]
if x.shape[2] > self.time_compress:
tc = self.time_compress
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim=2)
x = xl
else:
x = [x]
out = []
conv_carry_in = None
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
# downsampling
x1 = [x1]
h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
if len(self.down[i_level].attn) > 0:
assert i == 0 # carried should not happen if attn exists
h1 = self.down[i_level].attn[i_block](h1)
if i_level != self.num_resolutions - 1:
h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
out.append(h1)
conv_carry_in = conv_carry_out
h = torch_cat_if_needed(out, dim=2)
del out
# middle
h = self.mid.block_1(h, temb)
@ -616,15 +692,15 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h = [nonlinearity(h)]
h = conv_carry_causal_3d(h, self.conv_out)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
resolution, z_channels, tanh_out=False, use_linear_attn=False,
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
@ -638,12 +714,18 @@ class Decoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
self.carried = False
if conv3d:
conv_op = VideoConv3d
conv_out_op = VideoConv3d
if not attn_resolutions and resnet_op == ResnetBlock:
conv_op = CarriedConv3d
conv_out_op = CarriedConv3d
self.carried = True
else:
conv_op = VideoConv3d
conv_out_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@ -718,29 +800,43 @@ class Decoder(nn.Module):
temb = None
# z to block_in
h = self.conv_in(z)
h = conv_carry_causal_3d([z], self.conv_in)
# middle
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
if self.carried:
h = torch.split(h, 2, dim=2)
else:
h = [h]
out = []
conv_carry_in = None
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0:
h = self.up[i_level].upsample(h)
for i, h1 in enumerate(h):
conv_carry_out = []
if i == len(h) - 1:
conv_carry_out = None
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
if len(self.up[i_level].attn) > 0:
assert i == 0 # carried should not happen if attn exists
h1 = self.up[i_level].attn[i_block](h1, **kwargs)
if i_level != 0:
h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
# end
if self.give_pre_end:
return h
h1 = self.norm_out(h1)
h1 = [nonlinearity(h1)]
h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
if self.tanh_out:
h1 = torch.tanh(h1)
out.append(h1)
conv_carry_in = conv_carry_out
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h, **kwargs)
if self.tanh_out:
h = torch.tanh(h)
return h
out = torch_cat_if_needed(out, dim=2)
return out

View File

@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)

View File

@ -10,6 +10,7 @@ from ..flux.layers import EmbedND
from ..lightricks.model import TimestepEmbedding, Timesteps
from ..modules.attention import optimized_attention_masked
from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from ..flux.math import apply_rope1
class GELU(nn.Module):
@ -137,33 +138,34 @@ class Attention(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]:
if transformer_options is None:
transformer_options = {}
batch_size = hidden_states.shape[0]
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@ -239,10 +241,10 @@ class QwenImageTransformerBlock(nn.Module):
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated,
@ -251,16 +253,20 @@ class QwenImageTransformerBlock(nn.Module):
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
del img_modulated
del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_normed2 = self.img_norm2(hidden_states)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_normed2 = self.txt_norm2(encoder_hidden_states)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
return encoder_hidden_states, hidden_states
@ -421,7 +427,7 @@ class QwenImageTransformer2DModel(nn.Module):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
@ -441,7 +447,10 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -246,6 +246,7 @@ class WanAttentionBlock(nn.Module):
# assert e[0].dtype == torch.float32
# self-attention
x = x.contiguous() # otherwise implicit in LayerNorm
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
@ -615,7 +616,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@ -628,10 +629,22 @@ class WanModel(torch.nn.Module):
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
@ -661,7 +674,7 @@ class WanModel(torch.nn.Module):
if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):

View File

@ -326,6 +326,22 @@ def model_lora_keys_unet(model, key_map=None):
key_map["transformer.{}".format(key_lora)] = k
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # SimpleTuner lycoris format
if isinstance(model, comfy.model_base.Lumina2):
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
if isinstance(model, comfy.model_base.Kandinsky5):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
key_map["transformer.{}".format(key_lora)] = k
return key_map

View File

@ -53,6 +53,7 @@ from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentati
from .ldm.chroma_radiance import model as chroma_radiance
from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel
from .ldm.pixart.pixartms import PixArtMS
from .ldm.kandinsky5.model import Kandinsky5
from .ldm.qwen_image.model import QwenImageTransformer2DModel
from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel, WanModel_S2V, HumoWanModel
from .ldm.wan.model_animate import AnimateWanModel
@ -149,7 +150,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.operations = operations
@ -216,8 +217,14 @@ class BaseModel(torch.nn.Module):
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
def process_timestep(self, timestep, **kwargs):
return timestep
@ -343,10 +350,6 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
@ -921,12 +924,13 @@ class Flux(BaseModel):
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
shape = kwargs["noise"].shape
mask_ref_size = kwargs["attention_mask_img_shape"]
# the model will pad to the patch size, and then divide
# essentially dividing and rounding up
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = conds.CONDRegular(attention_mask)
mask_ref_size = kwargs.get("attention_mask_img_shape", None)
if mask_ref_size is not None:
# the model will pad to the patch size, and then divide
# essentially dividing and rounding up
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = conds.CONDRegular(attention_mask)
guidance = kwargs.get("guidance", 3.5)
if guidance is not None:
@ -948,7 +952,19 @@ class Flux(BaseModel):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class Flux2(Flux):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
target_text_len = 512
if cross_attn.shape[1] < target_text_len:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
out['c_crossattn'] = conds.CONDRegular(cross_attn)
return out
@ -1135,6 +1151,12 @@ class Lumina2(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDRegular(cross_attn)
if 'num_tokens' not in out:
out['num_tokens'] = conds.CONDConstant(cross_attn.shape[1])
clip_text_pooled = kwargs["pooled_output"] # Newbie
if clip_text_pooled is not None:
out['clip_text_pooled'] = conds.CONDRegular(clip_text_pooled)
return out
@ -1580,3 +1602,144 @@ class HunyuanImage21Refiner(HunyuanImage21):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = conds.CONDConstant(True)
return out
class HunyuanVideo15(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 # noise 32 img cond 32 + mask 1
if extra_channels == 0:
return None
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
shape_image = list(noise.shape)
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
for i in range(0, image.shape[1], latent_dim):
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :1]
else:
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDRegular(cross_attn)
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
if conditioning_byt5small is not None:
out['txt_byt5'] = conds.CONDRegular(conditioning_byt5small)
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance]))
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = conds.CONDRegular(clip_vision_output.last_hidden_state)
return out
class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
image = kwargs.get("concat_latent_image", None)
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
if image is None:
image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=model_management.intermediate_device())
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
# image = self.process_latent_in(image) # scaling wasn't applied in reference code
image = utils.resize_to_batch_size(image, noise.shape[0])
lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1)
if noise_augmentation > 0:
generator = torch.Generator(device="cpu")
generator.manual_seed(kwargs.get("seed", 0) - 10)
noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice]
else:
image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice]
return image
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = conds.CONDConstant(False)
return out
class Kandinsky5(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=Kandinsky5)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
device = kwargs["device"]
image = torch.zeros_like(noise)
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :1]
else:
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = conds.CONDRegular(cross_attn)
time_dim_replace = kwargs.get("time_dim_replace", None)
if time_dim_replace is not None:
out['time_dim_replace'] = conds.CONDRegular(self.process_latent_in(time_dim_replace))
return out
class Kandinsky5Image(Kandinsky5):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
return None

View File

@ -8,6 +8,7 @@ import torch
from . import supported_models, utils
from . import supported_models_base
from .gguf import GGMLOps
from .utils import detect_layer_quantization
logger = logging.getLogger(__name__)
@ -180,30 +181,71 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
# HunyuanVideo 1.5
if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
dit_config["use_cond_type_embedding"] = True
else:
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
else:
dit_config["vision_in_dim"] = None
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): # Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
dit_config["image_model"] = "flux"
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
dit_config["axes_dim"] = [32, 32, 32, 32]
dit_config["num_heads"] = 48
dit_config["mlp_ratio"] = 3.0
dit_config["theta"] = 2000
dit_config["out_channels"] = 128
dit_config["global_modulation"] = True
dit_config["mlp_silu_act"] = True
dit_config["qkv_bias"] = False
dit_config["ops_bias"] = False
dit_config["default_ref_method"] = "index"
dit_config["ref_index_scale"] = 10.0
dit_config["txt_ids_dims"] = [3]
patch_size = 1
else:
dit_config["image_model"] = "flux"
dit_config["axes_dim"] = [16, 56, 56]
dit_config["num_heads"] = 24
dit_config["mlp_ratio"] = 4.0
dit_config["theta"] = 10000
dit_config["out_channels"] = 16
dit_config["qkv_bias"] = True
dit_config["txt_ids_dims"] = []
patch_size = 2
dit_config["in_channels"] = 16
patch_size = 2
dit_config["hidden_size"] = 3072
dit_config["context_in_dim"] = 4096
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
dit_config["out_channels"] = 16
w = state_dict[in_key]
dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
dit_config["hidden_size"] = w.shape[0]
txt_in_key = "{}txt_in.weight".format(key_prefix)
if txt_in_key in state_dict_keys:
w = state_dict[txt_in_key]
dit_config["context_in_dim"] = w.shape[1]
dit_config["hidden_size"] = w.shape[0]
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
else:
dit_config["vec_in_dim"] = None
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: # Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
@ -226,6 +268,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_embedder_dtype"] = torch.float32
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: # Genmo mochi preview
@ -372,14 +419,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
dit_config["dim"] = w.shape[0]
dit_config["cap_feat_dim"] = w.shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512]
if dit_config["dim"] == 2304: # Original Lumina 2
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
if ctd_weight is not None:
dit_config["clip_text_dim"] = ctd_weight.shape[0]
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
dit_config["axes_dims"] = [32, 48, 48]
dit_config["axes_lens"] = [1536, 512, 512]
dit_config["rope_theta"] = 256.0
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
dit_config["z_image_modulation"] = True
dit_config["time_scale"] = 1000.0
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
@ -556,6 +623,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
dit_config["model_dim"] = model_dim
if model_dim in [4096, 2560]: # pro video and lite image
dit_config["axes_dims"] = (32, 48, 48)
if model_dim == 2560: # lite image
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
elif model_dim == 1792: # lite video
dit_config["axes_dims"] = (16, 24, 24)
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
dit_config["image_model"] = "kandinsky5"
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -699,16 +784,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match:
model_config = supported_models_base.BASE(unet_config)
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
model_config.optimizations["fp8"] = False
else:
model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
quant_config = detect_layer_quantization(state_dict, unet_key_prefix)
if quant_config:
model_config.quant_config = quant_config
logging.info("Detected mixed precision quantization")
if metadata is not None and "format" in metadata and metadata["format"] == "gguf":
model_config.custom_operations = GGMLOps

View File

@ -114,6 +114,7 @@ if args.deterministic:
directml_device = None
if args.directml is not None:
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
import torch_directml # pylint: disable=import-error
device_index = args.directml
@ -380,15 +381,20 @@ except:
pass
SUPPORT_FP8_OPS = args.supports_fp8_compute
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
try:
if is_amd():
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logger.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
logger.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logger.debug("AMD arch: {}".format(arch))
logger.debug("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@ -557,6 +563,7 @@ class LoadedModel:
if use_more_vram == 0:
use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
real_model = self.model.model
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
@ -811,8 +818,11 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0:
lowvram_model_memory = 0.1
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1
@ -1149,13 +1159,6 @@ def device_supports_non_blocking(device):
return True
def device_should_use_non_blocking(device):
if not device_supports_non_blocking(device):
return False
return False
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
def force_channels_last():
if args.force_channels_last:
return True
@ -1165,57 +1168,77 @@ def force_channels_last():
STREAMS = {}
NUM_STREAMS = 1
if args.async_offload:
NUM_STREAMS = 2
NUM_STREAMS = 0
if args.async_offload is not None:
NUM_STREAMS = args.async_offload
else:
# Enable by default on Nvidia
if is_nvidia():
NUM_STREAMS = 2
if args.disable_async_offload:
NUM_STREAMS = 0
if NUM_STREAMS > 0:
logger.debug("Using async weight offloading with {} streams".format(NUM_STREAMS))
def current_stream(device):
if device is None:
return None
if is_device_cuda(device):
return torch.cuda.current_stream()
elif is_device_xpu(device):
return torch.xpu.current_stream()
else:
return None
stream_counters = {}
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
if NUM_STREAMS <= 1:
if NUM_STREAMS == 0:
return None
if torch.compiler.is_compiling():
return None
if device in STREAMS:
ss = STREAMS[device]
s = ss[stream_counter]
# Sync the oldest stream in the queue with the current
ss[stream_counter].wait_stream(current_stream(device))
stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter
return s
return ss[stream_counter]
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
s1 = torch.cuda.Stream(device=device, priority=0)
s1.as_context = torch.cuda.stream
ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.xpu.Stream(device=device, priority=0))
s1 = torch.xpu.Stream(device=device, priority=0)
s1.as_context = torch.xpu.stream
ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None
def sync_stream(device, stream):
if stream is None:
if stream is None or current_stream(device) is None:
return
if is_device_cuda(device):
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
current_stream(device).wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
@ -1224,12 +1247,18 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None or weight.dtype == dtype:
return weight
if stream is not None:
with stream:
wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy)
if stream is not None:
with stream:
wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:
@ -1243,6 +1272,85 @@ def cast_to_device(tensor, device, dtype, copy=False):
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
return False
if not is_device_cpu(tensor.device):
return False
if tensor.is_pinned():
# NOTE: Cuda does detect when a tensor is already pinned and would
# error below, but there are proven cases where this also queues an error
# on the GPU async. So dont trust the CUDA API and guard here
return False
if not tensor.is_contiguous():
return False
size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
ptr = tensor.data_ptr()
if ptr == 0:
return False
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True
return False
def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_device_cpu(tensor.device):
return False
ptr = tensor.data_ptr()
size = tensor.numel() * tensor.element_size()
size_stored = PINNED_MEMORY.get(ptr, None)
if size_stored is None:
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
return False
if size != size_stored:
logging.warning("Size of pinned tensor changed")
return False
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
return False
def sage_attention_enabled():
return args.use_sage_attention
@ -1531,7 +1639,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_amd():
arch = torch.cuda.get_device_properties(device).gcnArchName
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
if manual_cast:
return True
return False
@ -1607,6 +1715,23 @@ def extended_fp16_support():
return True
LORA_COMPUTE_DTYPES = {}
def lora_compute_dtype(device):
dtype = LORA_COMPUTE_DTYPES.get(device, None)
if dtype is not None:
return dtype
if should_use_fp16(device):
dtype = torch.float16
else:
dtype = torch.float32
LORA_COMPUTE_DTYPES[device] = dtype
return dtype
def soft_empty_cache(force=False):
with model_management_lock:
_soft_empty_cache(force=force)

View File

@ -353,6 +353,9 @@ class MemoryMeasurements:
current_weight_patches_uuid: Any = None
_device: torch.device | None = None
def __init__(self):
self.model_offload_buffer_memory = None
@property
def device(self) -> torch.device:
if isinstance(self.model, DeviceSettable):

View File

@ -40,8 +40,10 @@ from .component_model.deprecation import _deprecate_method
from .float import stochastic_rounding
from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps
from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks
from .lora import calculate_weight
from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue, PatchSupport
from .model_base import BaseModel
from .model_management import lora_compute_dtype
from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem, TrainingSupport, HooksSupport
from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
@ -144,27 +146,23 @@ class LowVramPatch:
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
self.convert_func = convert_func
self.convert_func = convert_func # TODO: remove
self.set_func = set_func
def __call__(self, weight):
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
return calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: # intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
out = lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
out = lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
# The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key)
if weight is None:
return 0
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
set_func = None
@ -253,7 +251,6 @@ class ModelPatcher(ModelManageable, PatchSupport):
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
self._model_options: ModelOptions = {"transformer_options": {}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
@ -262,6 +259,7 @@ class ModelPatcher(ModelManageable, PatchSupport):
self.patches_uuid: uuid.UUID = uuid.uuid4()
self.ckpt_name = ckpt_name
self._memory_measurements = MemoryMeasurements(self.model)
self.pinned = set()
self.attachments: dict[str] = {}
self.additional_models: dict[str, list[ModelPatcher]] = {}
self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks()
@ -322,17 +320,28 @@ class ModelPatcher(ModelManageable, PatchSupport):
def lowvram_patch_counter(self):
return self._memory_measurements.lowvram_patch_counter
@property
def model_offload_buffer_memory(self) -> int:
return self._memory_measurements.model_offload_buffer_memory
@model_offload_buffer_memory.setter
def model_offload_buffer_memory(self, value):
self._memory_measurements.model_offload_buffer_memory = value
def model_size(self):
if self.size > 0:
return self.size
self.size = model_management.module_size(self.model)
return self.size
def get_ram_usage(self):
return self.model_size()
def loaded_size(self):
return self._memory_measurements.model_loaded_weight_memory
def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n._memory_measurements = self._memory_measurements
n.ckpt_name = self.ckpt_name
n.patches = {}
@ -346,6 +355,7 @@ class ModelPatcher(ModelManageable, PatchSupport):
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n._parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights
@ -430,8 +440,11 @@ class ModelPatcher(ModelManageable, PatchSupport):
return True
def memory_required(self, input_shape) -> int:
assert isinstance(self.model, BaseModel)
return self.model.memory_required(input_shape=input_shape)
if isinstance(self.model, BaseModel):
return self.model.memory_required(input_shape=input_shape)
else:
# todo: some other heuristic to determine memory required
raise ValueError("unexpected call to memory required on object that doesn't have a BaseModel but is using ModelPatcher")
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
@ -504,6 +517,18 @@ class ModelPatcher(ModelManageable, PatchSupport):
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
rope_options["scale_y"] = scale_y
rope_options["scale_t"] = scale_t
rope_options["shift_x"] = shift_x
rope_options["shift_y"] = shift_y
rope_options["shift_t"] = shift_t
self.model_options["transformer_options"]["rope_options"] = rope_options
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@ -686,10 +711,11 @@ class ModelPatcher(ModelManageable, PatchSupport):
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = lora_compute_dtype(device_to)
if device_to is not None:
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
temp_weight = model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
temp_weight = weight.to(temp_dtype, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
@ -703,6 +729,21 @@ class ModelPatcher(ModelManageable, PatchSupport):
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def pin_weight_to_device(self, key):
weight, set_func, convert_func = get_key_weight(self.model, key)
if model_management.pin_memory(weight):
self.pinned.add(key)
def unpin_weight(self, key):
if key in self.pinned:
weight, set_func, convert_func = get_key_weight(self.model, key)
model_management.unpin_memory(weight)
self.pinned.remove(key)
def unpin_all_weights(self):
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self) -> list[LoadingListItem]:
loading = []
for n, m in self.model.named_modules():
@ -715,7 +756,16 @@ class ModelPatcher(ModelManageable, PatchSupport):
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append(LoadingListItem(model_management.module_size(m), n, m, params))
module_mem = model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if weight_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
if bias_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
loading.append(LoadingListItem(module_offload_mem, module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@ -727,25 +777,30 @@ class ModelPatcher(ModelManageable, PatchSupport):
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
lowvram_mem_counter = 0
loading = self._load_list()
load_completely: list[LoadingListItem] = []
offloaded = []
offload_buffer = 0
loading.sort(reverse=True)
for x in loading:
n = x.name
m = x.module
params = x.params
module_mem = x.module_size
for i, x in enumerate(loading):
module_offload_mem, module_mem, n, m, params = x
lowvram_weight = False
potential_offload = max(offload_buffer, module_offload_mem + sum([x1[1] for x1 in loading[i + 1:i + 1 + model_management.NUM_STREAMS]]))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if not full_load and hasattr(m, "comfy_cast_weights"):
if mem_counter + module_mem >= lowvram_model_memory:
if not lowvram_fits:
offload_buffer = potential_offload
lowvram_weight = True
lowvram_counter += 1
lowvram_mem_counter += module_mem
if hasattr(m, "prev_comfy_cast_weights"): # Already lowvramed
continue
@ -771,13 +826,16 @@ class ModelPatcher(ModelManageable, PatchSupport):
patch_counter += 1
cast_weight = True
offloaded.append((module_mem, n, m, params))
else:
if hasattr(m, "comfy_cast_weights"):
wipe_lowvram_weight(m)
if full_load or mem_counter + module_mem < lowvram_model_memory:
if full_load or lowvram_fits:
mem_counter += module_mem
load_completely.append(LoadingListItem(module_mem, n, m, params))
else:
offload_buffer = potential_offload
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
@ -802,7 +860,11 @@ class ModelPatcher(ModelManageable, PatchSupport):
continue
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
if model_management.is_device_cuda(device_to):
torch.cuda.synchronize()
models_loaded_regularly.append("name={} module={}".format(n, m))
m.comfy_patched_weights = True
@ -810,11 +872,21 @@ class ModelPatcher(ModelManageable, PatchSupport):
for x in load_completely:
x.module.to(device_to)
for x in offloaded:
n = x[1]
params = x[3]
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
if lowvram_counter > 0:
logger.debug(f"loaded partially lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB patch_counter={patch_counter}")
if hasattr(self.model, "model_lovram"):
self.model.model_lowvram = True
self._memory_measurements.model_lowvram = True
else:
logger.debug(f"loaded completely lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB full_load={full_load}")
if hasattr(self.model, "model_lovram"):
self.model.model_lowvram = False
self._memory_measurements.model_lowvram = False
if full_load:
self.model.to(device_to)
@ -846,6 +918,7 @@ class ModelPatcher(ModelManageable, PatchSupport):
self.model_device = device_to
self._memory_measurements.model_loaded_weight_memory = mem_counter
self._memory_measurements.model_offload_buffer_memory = offload_buffer
self._memory_measurements.current_weight_patches_uuid = self.patches_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
@ -881,6 +954,7 @@ class ModelPatcher(ModelManageable, PatchSupport):
p.patches = []
if unpatch_weights:
self.unpatch_hooks()
self.unpin_all_weights()
if self._memory_measurements.model_lowvram:
for m in self.model.modules():
move_weight_functions(m, device_to)
@ -907,6 +981,7 @@ class ModelPatcher(ModelManageable, PatchSupport):
self.model.to(device_to)
self.model_device = device_to
self._memory_measurements.model_loaded_weight_memory = 0
self._memory_measurements.model_offload_buffer_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
@ -918,7 +993,7 @@ class ModelPatcher(ModelManageable, PatchSupport):
self.object_patches_backup.clear()
def partially_unload(self, device_to, memory_to_free=0):
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
freed_layers: list[str] = []
with self.use_ejected():
hooks_unpatched = False
@ -926,13 +1001,18 @@ class ModelPatcher(ModelManageable, PatchSupport):
patch_counter = 0
unload_list = self._load_list()
unload_list.sort()
offload_buffer = self._memory_measurements.model_offload_buffer_memory
if len(unload_list) > 0:
NS = model_management.NUM_STREAMS
offload_weight_factor = [min(offload_buffer / (NS + 1), unload_list[0][1])] * NS
for unload in unload_list:
if memory_to_free < memory_freed:
if memory_to_free + offload_buffer - self._memory_measurements.model_offload_buffer_memory < memory_freed:
break
module_mem = unload[0]
n = unload[1]
m = unload[2]
params = unload[3]
module_offload_mem, module_mem, n, m, params = unload
potential_offload = module_offload_mem + sum(offload_weight_factor)
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@ -963,13 +1043,19 @@ class ModelPatcher(ModelManageable, PatchSupport):
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
_, set_func, convert_func = get_key_weight(self.model, weight_key)
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
patch_counter += 1
if bias_key in self.patches:
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
_, set_func, convert_func = get_key_weight(self.model, bias_key)
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
patch_counter += 1
cast_weight = True
if cast_weight:
@ -978,12 +1064,19 @@ class ModelPatcher(ModelManageable, PatchSupport):
m.comfy_patched_weights = False
memory_freed += module_mem
freed_layers.append(n)
offload_buffer = max(offload_buffer, potential_offload)
offload_weight_factor.append(module_mem)
offload_weight_factor.pop(0)
logger.debug("freed {}".format(natsorted(freed_layers)))
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
logger.debug(f"Freed {natsorted(freed_layers)}")
self._memory_measurements.model_lowvram = True
self._memory_measurements.lowvram_patch_counter += patch_counter
self._memory_measurements.model_loaded_weight_memory -= memory_freed
self._memory_measurements.model_offload_buffer_memory = offload_buffer
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False) -> int:
@ -996,6 +1089,9 @@ class ModelPatcher(ModelManageable, PatchSupport):
extra_memory += (used - self._memory_measurements.model_loaded_weight_memory)
self.patch_model(load_weights=False)
if extra_memory < 0 and not unpatch_weights:
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
return 0
full_load = False
if not self._memory_measurements.model_lowvram and self._memory_measurements.model_loaded_weight_memory > 0:
return 0
@ -1399,4 +1495,5 @@ class ModelPatcher(ModelManageable, PatchSupport):
self.clear_cached_hook_weights()
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)

91
comfy/nested_tensor.py Normal file
View File

@ -0,0 +1,91 @@
import torch
class NestedTensor:
def __init__(self, tensors):
self.tensors = list(tensors)
self.is_nested = True
def _copy(self):
return NestedTensor(self.tensors)
def apply_operation(self, other, operation):
o = self._copy()
if isinstance(other, NestedTensor):
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other.tensors[i])
else:
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other)
return o
def __add__(self, b):
return self.apply_operation(b, lambda x, y: x + y)
def __sub__(self, b):
return self.apply_operation(b, lambda x, y: x - y)
def __mul__(self, b):
return self.apply_operation(b, lambda x, y: x * y)
# def __itruediv__(self, b):
# return self.apply_operation(b, lambda x, y: x / y)
def __truediv__(self, b):
return self.apply_operation(b, lambda x, y: x / y)
def __getitem__(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
def unbind(self):
return self.tensors
def to(self, *args, **kwargs):
o = self._copy()
for i, t in enumerate(o.tensors):
o.tensors[i] = t.to(*args, **kwargs)
return o
def new_ones(self, *args, **kwargs):
return self.tensors[0].new_ones(*args, **kwargs)
def float(self):
return self.to(dtype=torch.float)
def chunk(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
def size(self):
return self.tensors[0].size()
@property
def shape(self):
return self.tensors[0].shape
@property
def ndim(self):
dims = 0
for t in self.tensors:
dims = max(t.ndim, dims)
return dims
@property
def device(self):
return self.tensors[0].device
@property
def dtype(self):
return self.tensors[0].dtype
@property
def layout(self):
return self.tensors[0].layout
def cat_nested(tensors, *args, **kwargs):
cated_tensors = []
for i in range(len(tensors[0].tensors)):
tens = []
for j in range(len(tensors)):
tens.append(tensors[j].tensors[i])
cated_tensors.append(torch.cat(tens, *args, **kwargs))
return NestedTensor(cated_tensors)

View File

@ -745,8 +745,10 @@ class LoraLoaderModelOnly(LoraLoader):
class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
@staticmethod
def vae_list():
def vae_list(s):
vaes = get_filename_list_with_downloadable("vae", KNOWN_VAES)
approx_vaes = get_filename_list_with_downloadable("vae_approx", KNOWN_APPROX_VAES)
sdxl_taesd_enc = False
@ -775,6 +777,11 @@ class VAELoader:
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
else:
for tae in s.video_taes:
if v.startswith(tae):
vaes.append(v)
if sd1_taesd_dec and sd1_taesd_enc:
vaes.append("taesd")
if sdxl_taesd_dec and sdxl_taesd_enc:
@ -818,8 +825,7 @@ class VAELoader:
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae_name": (s.vae_list(),)}}
return {"required": {"vae_name": (s.vae_list(s),)}}
RETURN_TYPES = ("VAE",)
FUNCTION = "load_vae"
@ -831,10 +837,13 @@ class VAELoader:
if vae_name == "pixel_space":
sd_ = {}
sd_["pixel_space_vae"] = torch.tensor(1.0)
elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
elif vae_name in self.image_taes:
sd_ = self.load_taesd(vae_name)
else:
vae_path = get_full_path_or_raise("vae", vae_name, KNOWN_VAES)
if os.path.splitext(vae_name)[0] in self.video_taes:
vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name)
else:
vae_path = get_full_path_or_raise("vae", vae_name, KNOWN_VAES)
sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name=vae_name)
vae.throw_exception_if_invalid()
@ -1016,7 +1025,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": {"clip_name": (get_filename_list_with_downloadable("text_encoders", KNOWN_CLIP_MODELS),),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"],),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"],),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -1046,7 +1055,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": {"clip_name1": (get_filename_list_with_downloadable("text_encoders"),), "clip_name2": (
get_filename_list_with_downloadable("text_encoders"),),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"],),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"],),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -2003,6 +2012,11 @@ class ImageBatch:
CATEGORY = "image"
def batch(self, image1, image2):
if image1.shape[-1] != image2.shape[-1]:
if image1.shape[-1] > image2.shape[-1]:
image2 = torch.nn.functional.pad(image2, (0,1), mode='constant', value=1.0)
else:
image1 = torch.nn.functional.pad(image1, (0,1), mode='constant', value=1.0)
if image1.shape[1:] != image2.shape[1:]:
image2 = utils.common_upscale(image2.movedim(-1, 1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1, -1)
s = torch.cat((image1, image2), dim=0)

View File

@ -17,20 +17,22 @@
"""
import contextlib
import logging
import torch
from torch import Tensor
from typing import Optional, Type, Union
import torch
from torch import Tensor
from . import model_management, rmsnorm
from .interruption import throw_exception_if_processing_interrupted
from .cli_args import args, PerformanceFeature
from .execution_context import current_execution_context
from .float import stochastic_rounding
from .interruption import throw_exception_if_processing_interrupted
logger = logging.getLogger(__name__)
_RUN_EVERY_OP_ENABLED = model_management.torch_version_numeric >= (2, 5)
import json
def run_every_op():
global _RUN_EVERY_OP_ENABLED
@ -49,7 +51,7 @@ def _scaled_dot_product_attention(q, k, v, *args, **kwargs):
try:
if torch.cuda.is_available():
if torch.cuda.is_available() and model_management.WINDOWS:
from torch.nn.attention import SDPBackend, sdpa_kernel # pylint: disable=import-error
import inspect
@ -82,7 +84,8 @@ except Exception as exc_info:
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
try:
if model_management.is_nvidia():
if torch.backends.cudnn.version() >= 91002 and model_management.torch_version_numeric >= (2, 9) and model_management.torch_version_numeric <= (2, 10):
cudnn_version = torch.backends.cudnn.version()
if (cudnn_version >= 91002 and cudnn_version < 91500) and model_management.torch_version_numeric >= (2, 9) and model_management.torch_version_numeric <= (2, 10):
# TODO: change upper bound version once it's fixed'
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
logger.debug("working around nvidia conv3d memory bug.")
@ -96,41 +99,81 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@torch.compiler.disable()
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
if input is not None:
if dtype is None:
dtype = input.dtype
if isinstance(input, QuantizedTensor):
dtype = input._layout_params["orig_dtype"]
else:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
offload_stream = model_management.get_offload_stream(device)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = model_management.get_offload_stream(device)
else:
offload_stream = None
if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()
bias = None
non_blocking = True if torch.jit.is_tracing() or torch.jit.is_scripting() else model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = len(s.bias_function) > 0
bias = model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
with wf_context:
for f in s.bias_function:
bias = f(bias)
# todo: how is wf_context used?
non_blocking = model_management.device_supports_non_blocking(device)
has_function = len(s.weight_function) > 0
weight = model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
with wf_context:
for f in s.weight_function:
weight = f(weight)
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
if s.bias is not None:
bias = model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
model_management.sync_stream(device, offload_stream)
return weight, bias
bias_a = bias
weight_a = weight
if s.bias is not None:
for f in s.bias_function:
bias = f(bias)
if weight_has_function or weight.dtype != dtype:
weight = weight.to(dtype=dtype)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
for f in s.weight_function:
weight = f(weight)
if offloadable:
return weight, bias, (offload_stream, weight_a, bias_a)
else:
# Legacy function signature
return weight, bias
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
os, weight_a, bias_a = offload_stream
if os is None:
return
if weight_a is not None:
device = weight_a.device
else:
if bias_a is None:
return
device = bias_a.device
os.wait_stream(model_management.current_stream(device))
class SkipInit:
@ -191,8 +234,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@ -206,8 +251,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@ -221,8 +268,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@ -245,8 +294,10 @@ class disable_weight_init:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._conv_forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@ -260,8 +311,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@ -276,11 +329,14 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
bias = None
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
offload_stream = None
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@ -296,11 +352,15 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
else:
weight = None
return rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
bias = None
offload_stream = None
x = rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@ -319,10 +379,12 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@ -341,10 +403,12 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@ -362,8 +426,10 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
@ -417,48 +483,33 @@ class manual_cast(disable_weight_init):
def fp8_linear(self, input):
"""
Legacy FP8 linear function for backward compatibility.
Uses QuantizedTensor subclass for dispatch.
"""
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
else:
scale_weight = scale_weight.to(input.device)
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
else:
scale_input = scale_input.to(input.device)
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
if isinstance(o, tuple):
o = o[0]
if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
uncast_bias_weight(self, w, bias, offload_stream)
return o
return None
@ -471,7 +522,7 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
if not self.training:
if len(self.weight_function) == 0 and len(self.bias_function) == 0:
try:
out = fp8_linear(self, input)
if out is not None:
@ -479,67 +530,16 @@ class fp8_ops(manual_cast):
except Exception as e:
logger.info("Exception during fp8 op: {}".format(e))
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = torch.nn.functional.linear(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
class scaled_fp8_op_base(manual_cast):
pass
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logger.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
class scaled_fp8_op(scaled_fp8_op_base):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
if override_dtype is not None:
kwargs['dtype'] = override_dtype
super().__init__(*args, **kwargs)
def reset_parameters(self):
if not hasattr(self, 'scale_weight'):
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
if not scale_input:
self.scale_input = None
if not hasattr(self, 'scale_input'):
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
return None
def forward_comfy_cast_weights(self, input):
if fp8_matrix_mult:
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
if weight.numel() < input.numel(): # TODO: optimize
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight
else:
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update:
self.weight.data.copy_(weight)
else:
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return scaled_fp8_op
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
@ -565,14 +565,178 @@ else:
Operations = Type[Union[manual_cast, fp8_ops, disable_weight_init, skip_init, scaled_fp8_op_base]]
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8: Optional[torch.dtype] = None, inference_mode: Optional[bool] = None) -> Operations:
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
def reset_parameters(self):
return None
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm:
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
weight_scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(weight_scale_key, None)
if scale is not None:
scale = scale.to(device)
layout_params = {
'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if scale is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor):
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
return sd
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, input, *args, **kwargs):
run_every_op()
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
if isinstance(weight, QuantizedTensor):
return weight.dequantize()
else:
return weight
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
return weight
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
self.weight = torch.nn.Parameter(weight, requires_grad=False)
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
if recurse:
for module in self.children():
module._apply(fn)
for key, param in self._parameters.items():
if param is None:
continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None, inference_mode: Optional[bool] = None):
if inference_mode is None:
# todo: check a context here, since this isn't being used by any callers yet
inference_mode = current_execution_context().inference_mode
fp8_compute = model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
fp8_compute = model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
if (
fp8_compute and

577
comfy/quant_ops.py Normal file
View File

@ -0,0 +1,577 @@
import torch
import logging
from typing import Tuple, Dict
import comfy.float
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
def _copy_layout_params_inplace(src, dst, non_blocking=False):
for k, v in src.items():
if isinstance(v, torch.Tensor):
dst[k].copy_(v, non_blocking=non_blocking)
else:
dst[k] = v
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
def data_ptr(self):
return self._qdata.data_ptr()
def is_pinned(self):
return self._qdata.is_pinned()
def is_contiguous(self, *arg, **kwargs):
return self._qdata.is_contiguous(*arg, **kwargs)
def storage(self):
return self._qdata.storage()
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
def _create_transformed_qtensor(qt, transform_fn):
new_data = transform_fn(qt._qdata)
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
if target_dtype is not None:
new_params["orig_dtype"] = target_dtype
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
non_blocking = args[2] if len(args) > 2 else False
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
orig_dtype = qt_dest._layout_params["orig_dtype"]
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
qt_dest._layout_params["orig_dtype"] = orig_dtype
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
"""Handle .to(dtype) calls - dtype conversion only."""
src = args[0]
if isinstance(src, QuantizedTensor):
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
src._layout_params["orig_dtype"] = target_dtype
return src
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
# Create empty tensor with same shape and dtype as the quantized data
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
new_qdata = torch.empty_like(qt._qdata, **kwargs)
# Handle device transfer for layout params
target_device = kwargs.get('device', new_qdata.device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
# Update orig_dtype if dtype is specified
new_params['orig_dtype'] = hp_dtype
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
return func(*args, **kwargs)
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
orig_dtype = tensor.dtype
if isinstance(scale, str) and scale == "recalculate":
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
if scale is not None:
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)
else:
tensor = tensor * (1.0 / scale).to(tensor.dtype)
else:
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
if stochastic_rounding > 0:
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
else:
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return tensor, layout_params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
plain_tensor.mul_(scale)
return plain_tensor
@classmethod
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
QUANT_ALGOS = {
"float8_e4m3fn": {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8Layout",
},
}
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]).contiguous(),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
except Exception as e:
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
# Case 2: DQ Fallback
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)

View File

@ -5,6 +5,21 @@ from . import model_management
from . import samplers
from . import utils
from .component_model.deprecation import _deprecate_method
from .nested_tensor import NestedTensor
def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1] + 1):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
return torch.cat(noises, axis=0)
def prepare_noise(latent_image, seed, noise_inds=None):
@ -13,36 +28,41 @@ def prepare_noise(latent_image, seed, noise_inds=None):
optional arg skip can be used to skip and discard x number of noise generations for a given seed
"""
generator = torch.manual_seed(seed)
if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = []
for i in range(unique_inds[-1]+1):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
if i in unique_inds:
noises.append(noise)
noises = [noises[i] for i in inverse]
noises = torch.cat(noises, axis=0)
if latent_image.is_nested:
tensors = latent_image.unbind()
noises = []
for t in tensors:
noises.append(prepare_noise_inner(t, generator, noise_inds))
noises = NestedTensor(noises)
else:
noises = prepare_noise_inner(latent_image, generator, noise_inds)
return noises
def fix_empty_latent_channels(model, latent_image):
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
if latent_image.is_nested:
return latent_image
latent_format = model.get_model_object("latent_format") # Resize the empty latent image so it has the right number of channels
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
latent_image = utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
latent_image = latent_image.unsqueeze(2)
return latent_image
@_deprecate_method(version="0.3.2", message="Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed")
def prepare_sampling(model, noise_shape, positive, negative, noise_mask):
pass
return model, positive, negative, noise_mask, []
@_deprecate_method(version="0.3.2", message="Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed")
def cleanup_additional_models(models):
pass
def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):
sampler = samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
@ -50,6 +70,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
samples = samples.to(model_management.intermediate_device())
return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
samples = samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(model_management.intermediate_device())

View File

@ -24,7 +24,7 @@ from .model_management_types import ModelOptions
from .model_patcher import ModelPatcher
from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES, KSAMPLER_NAMES
from .context_windows import ContextHandlerABC
from .utils import common_upscale
from .utils import common_upscale, pack_latents, unpack_latents
from .patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP
from .component_model import module_property
@ -827,7 +827,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
return KSAMPLER(sampler_function, extra_options, inpaint_options)
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
for k in conds:
conds[k] = conds[k][:]
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
@ -837,7 +837,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
if hasattr(model, 'extra_conds'):
for k in conds:
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
# make sure each cond area has an opposite one with the same area
for k in conds:
@ -1008,11 +1008,11 @@ class CFGGuider:
def predict_noise(self, x, timestep, model_options={}, seed=None):
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
def inner_sample(self, noise, latent_image, device, sampler: KSAMPLER, sigmas, denoise_mask, callback, disable_pbar, seed):
def inner_sample(self, noise, latent_image, device, sampler: KSAMPLER, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
if latent_image is not None and torch.count_nonzero(latent_image) > 0: # Don't shift the empty latent image.
latent_image = self.inner_model.process_latent_in(latent_image)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
extra_model_options = model_patcher.create_model_options_clone(self.model_options)
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
@ -1026,7 +1026,7 @@ class CFGGuider:
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return self.inner_model.process_latent_out(samples.to(torch.float32))
def outer_sample(self, noise, latent_image, sampler: KSAMPLER, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
def outer_sample(self, noise, latent_image, sampler: KSAMPLER, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
self.inner_model, self.conds, self.loaded_models = sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device
@ -1040,7 +1040,7 @@ class CFGGuider:
try:
self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally:
self.model_patcher.cleanup()
@ -1053,6 +1053,12 @@ class CFGGuider:
if sigmas.shape[-1] == 0:
return latent_image
if latent_image.is_nested:
latent_image, latent_shapes = pack_latents(latent_image.unbind())
noise, _ = pack_latents(noise.unbind())
else:
latent_shapes = [latent_image.shape]
self.conds = {}
for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
@ -1072,7 +1078,7 @@ class CFGGuider:
self,
patcher_extension.get_all_wrappers(patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
except ValueError as exc_info:
if "fp8e4nv" in str(exc_info):
logger.error(f"Load the weights for model {self.model_patcher} as fp8_e5m2 to use floating point 8-bit inference with torch.compile and triton on Ampere architecture")
@ -1084,6 +1090,9 @@ class CFGGuider:
self.model_patcher.restore_hook_patches()
del self.conds
if len(latent_shapes) > 1:
output = NestedTensor(unpack_latents(output, latent_shapes))
return output

View File

@ -6,12 +6,11 @@ import logging
import math
import os
import os.path
import torch
import yaml
from enum import Enum
from typing import Any, Optional
from humanize import naturalsize
import torch
import yaml
from . import clip_vision
from . import diffusers_convert
@ -34,14 +33,15 @@ from .ldm.flux.redux import ReduxImageEncoder
from .ldm.genmo.vae import model as genmo_model
from .ldm.hunyuan3d.vae import ShapeVAE
from .ldm.lightricks.vae import causal_video_autoencoder as lightricks
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.mmaudio.vae.autoencoder import AudioAutoencoder
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
from .ldm.wan import vae as wan_vae
from .ldm.wan import vae2_2 as wan_vae2_2
from .lora import load_lora, model_lora_keys_unet, model_lora_keys_clip
from .lora_convert import convert_lora
from .model_management import load_models_gpu
from .model_management import load_models_gpu, module_size
from .model_patcher import ModelPatcher
from .pixel_space_convert import PixelspaceConversionVAE
from .t2i_adapter import adapter
from .taesd import taesd
from .text_encoders import ace
@ -50,21 +50,25 @@ from .text_encoders import cosmos
from .text_encoders import flux
from .text_encoders import genmo
from .text_encoders import hidream
from .text_encoders import hunyuan_video
from .text_encoders import hunyuan_image
from .text_encoders import hunyuan_video
from .text_encoders import hydit
from .text_encoders import kandinsky5
from .text_encoders import long_clipl
from .text_encoders import lt
from .text_encoders import lumina2
from .text_encoders import omnigen2
from .text_encoders import ovis
from .text_encoders import pixart_t5
from .text_encoders import qwen_image
from .text_encoders import sa_t5
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
from .text_encoders import wan
from .utils import ProgressBar, FileMetadata
from .pixel_space_convert import PixelspaceConversionVAE
from .text_encoders import z_image
from .utils import ProgressBar, FileMetadata, state_dict_prefix_replace
from .taesd.taehv import TAEHV
from .latent_formats import HunyuanVideo15, HunyuanVideo
logger = logging.getLogger(__name__)
@ -101,7 +105,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_
class CLIP:
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0, model_options={}):
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0, state_dict=[], model_options={}):
if tokenizer_data is None:
tokenizer_data = dict()
if no_init:
@ -137,6 +141,27 @@ class CLIP:
self.patcher.hook_mode = EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
if len(state_dict) > 0:
if isinstance(state_dict, list):
for c in state_dict:
m, u = self.load_sd(c)
if len(m) > 0:
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected: {}".format(u))
else:
m, u = self.load_sd(state_dict, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
if len(m_filter) > 0:
logging.warning("clip missing: {}".format(m))
else:
logging.debug("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
@ -156,6 +181,9 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
@ -199,6 +227,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
@ -246,6 +275,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
@ -310,6 +340,7 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.size = None
self.downscale_index_formula = None
self.upscale_index_formula = None
@ -369,7 +400,7 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif sd['decoder.conv_in.weight'].shape[1] == 32:
elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@ -395,6 +426,17 @@ class VAE:
self.upscale_ratio = 4
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
if 'decoder.post_quant_conv.weight' in sd:
sd = state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
if 'bn.running_mean' in sd:
ddconfig["batch_norm_latent"] = True
self.downscale_ratio *= 2
self.upscale_ratio *= 2
self.latent_channels *= 4
old_memory_used_decode = self.memory_used_decode
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
if 'post_quant_conv.weight' in sd:
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
else:
@ -454,20 +496,20 @@ class VAE:
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.latent_channels = 64
self.latent_channels = 32
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
self.latent_dim = 3
self.not_video = True
self.not_video = False
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@ -479,8 +521,10 @@ class VAE:
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
# This is likely to significantly over-estimate with single image or low frame counts as the
# implementation is able to completely skip caching. Rework if used as an image only VAE
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
@ -509,13 +553,14 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
else: # Wan 2.1 VAE
dim = sd["decoder.head.0.gamma"].shape[0]
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = 16
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = wan_vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
@ -593,6 +638,35 @@ class VAE:
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.crop_input = False
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
self.latent_channels = sd["decoder.1.weight"].shape[1]
self.latent_dim = 3
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
self.upscale_index_formula = (4, 16, 16)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
self.downscale_index_formula = (4, 16, 16)
if self.latent_channels == 48: # Wan 2.2
self.first_stage_model = TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.process_output = lambda image: image
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
self.first_stage_model = TAEHV(latent_channels=self.latent_channels, latent_format=HunyuanVideo15)
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
latent_format = HunyuanVideo
else:
latent_format = None # lighttaew2_1 doesn't need scaling
self.first_stage_model = TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
self.process_input = self.process_output = lambda image: image
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
else:
logger.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -620,6 +694,8 @@ class VAE:
self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logger.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
# todo: why is this being called here? for what side effects exactly?
self.model_size()
def clone(self):
n = VAE(no_init=True)
@ -644,6 +720,15 @@ class VAE:
n.patcher = self.patcher.clone()
return n
def model_size(self):
if self.size is not None:
return self.size
self.size = module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
@ -704,6 +789,7 @@ class VAE:
return samples
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
extra_channel_size = 0
if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
out_channels = self.latent_channels
@ -730,6 +816,8 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -960,6 +1048,10 @@ class CLIPType(Enum):
OMNIGEN2 = 17
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20
OVIS = 21
KANDINSKY5 = 22
KANDINSKY5_IMAGE = 23
@dataclasses.dataclass
@ -975,8 +1067,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
model_options = dict()
clip_data = []
for p in ckpt_paths:
clip_data.append(utils.load_torch_file(p, safe_load=True))
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, textmodel_json_config=textmodel_json_config)
sd, metadata = utils.load_torch_file(p, safe_load=True, return_metadata=True)
if model_options.get("custom_operations", None) is None:
sd, metadata = utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
class TEModel(Enum):
@ -993,6 +1088,10 @@ class TEModel(Enum):
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
QWEN3_2B = 17
def detect_te_model(sd):
@ -1026,6 +1125,18 @@ def detect_te_model(sd):
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
if 'model.layers.0.self_attn.q_norm.weight' in sd:
if weight.shape[0] == 2560:
return TEModel.QWEN3_4B
elif weight.shape[0] == 2048:
return TEModel.QWEN3_2B
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
else:
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
return TEModel.LLAMA3_8
return None
@ -1077,7 +1188,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
clip_target.clip = hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.tokenizer = hidream.HiDreamTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
@ -1101,7 +1212,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = hidream.hidream_clip(**t5xxl_detect(clip_data),
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
clip_target.tokenizer = hidream.HiDreamTokenizer
else: # CLIPType.MOCHI
clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data))
@ -1130,7 +1241,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
clip_target.tokenizer = hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = omnigen2.te(**llama_detect(clip_data))
@ -1142,13 +1253,23 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
clip_target.clip = qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = qwen_image.QwenImageTokenizer
elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2:
clip_target.clip = flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = z_image.ZImageTokenizer
elif te_model == TEModel.QWEN3_2B:
clip_target.clip = ovis.te(**llama_detect(clip_data))
clip_target.tokenizer = ovis.OvisTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
clip_target.clip = hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.tokenizer = hidream.HiDreamTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
@ -1188,6 +1309,15 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = hunyuan_image.HunyuanImageTokenizer
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
clip_target.clip = hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = hunyuan_video.HunyuanVideo15Tokenizer
elif clip_type == CLIPType.KANDINSKY5:
clip_target.clip = kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = kandinsky5.Kandinsky5Tokenizer
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
clip_target.clip = kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = kandinsky5.Kandinsky5TokenizerImage
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@ -1203,14 +1333,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters += utils.calculate_parameters(c)
tokenizer_data, model_options = long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
logger.warning("clip missing: {}".format(m))
if len(u) > 0:
logger.debug("clip unexpected: {}".format(u))
clip = CLIP(clip_target, textmodel_json_config=textmodel_json_config, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
return clip
@ -1285,6 +1408,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
logger.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
@ -1294,16 +1421,21 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:
if model_config.quant_config is not None:
weight_dtype = None
model_config.custom_operations = model_options.get("custom_operations", None)
if custom_operations is not None:
model_config.custom_operations = custom_operations
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
if model_config.quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config.clip_vision_prefix is not None:
@ -1321,22 +1453,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
vae = VAE(sd=vae_sd, metadata=metadata)
if output_clip:
if te_model_options.get("custom_operations", None) is None:
scaled_fp8_list = []
for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
if k.endswith(".scaled_fp8"):
scaled_fp8_list.append(k[:-len("scaled_fp8")])
if len(scaled_fp8_list) > 0:
out_sd = {}
for k in sd:
skip = False
for pref in scaled_fp8_list:
skip = skip or k.startswith(pref)
if not skip:
out_sd[k] = sd[k]
for pref in scaled_fp8_list:
quant_sd, qmetadata = utils.convert_old_quants(sd, pref, metadata={})
for k in quant_sd:
out_sd[k] = quant_sd[k]
sd = out_sd
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
if len(m_filter) > 0:
logger.warning("clip missing: {}".format(m))
else:
logger.debug("clip missing: {}".format(m))
if len(u) > 0:
logger.debug("clip unexpected {}:".format(u))
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
else:
logger.warning(f"no CLIP/text encoder weights in checkpoint {ckpt_path}, the text encoder model will not be loaded.")
@ -1385,6 +1528,9 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
if len(temp_sd) > 0:
sd = temp_sd
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = utils.convert_old_quants(sd, "", metadata=metadata)
parameters = utils.calculate_parameters(sd)
weight_dtype = utils.weight_dtype(sd)
load_device = model_management.get_torch_device()
@ -1414,7 +1560,7 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:
if model_config.quant_config is not None:
weight_dtype = None
if dtype is None:
@ -1422,9 +1568,15 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
else:
unet_dtype = dtype
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
if model_config.quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if custom_operations is not None:
model_config.custom_operations = custom_operations
if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True
@ -1437,7 +1589,7 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device, ckpt_name=os.path.basename(ckpt_path))
def load_diffusion_model(unet_path, model_options: dict = None):
def load_diffusion_model(unet_path, model_options=None):
if model_options is None:
model_options = {}
sd, metadata = utils.load_torch_file(unet_path, return_metadata=True)
@ -1468,6 +1620,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if vae is not None:
vae_sd = vae.get_sd()
if metadata is None:
metadata = {}
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)

View File

@ -28,6 +28,7 @@ except ImportError:
logger = logging.getLogger(__name__)
def gen_empty_tokens(special_tokens, length):
start_token = special_tokens.get("start", None)
end_token = special_tokens.get("end", None)
@ -132,19 +133,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
config[k] = v
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
quant_config = model_options.get("quantization_metadata", None)
if operations is None:
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
if quant_config is not None:
operations = ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
logging.info("Using MixedPrecisionOps for text encoder")
else:
operations = ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
if scaled_fp8 is not None:
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
self.num_layers = self.transformer.num_layers
@ -162,6 +161,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
self.return_attention_masks = return_attention_masks
self.execution_device = None
if layer == "hidden":
assert layer_idx is not None
@ -178,7 +178,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
if self.layer == "all":
self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
@ -190,6 +191,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = self.options_default[0]
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
self.execution_device = None
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
@ -273,14 +275,20 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens):
device = self.transformer.get_input_embeddings().weight.device
if self.execution_device is None:
device = self.transformer.get_input_embeddings().weight.device
else:
device = self.execution_device
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None
if self.enable_attention_masks:
attention_mask_model = attention_mask
if self.layer == "all":
if isinstance(self.layer, list):
intermediate_output = self.layer
elif self.layer == "all":
intermediate_output = "all"
else:
intermediate_output = self.layer_idx
@ -478,6 +486,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
valid_file = None
for embed_dir in embedding_directory:
# todo: improve this, so that it is more compatible between linux and windows
embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name))
embed_dir = os.path.abspath(embed_dir)
try:
@ -546,7 +555,7 @@ SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
class SDTokenizer:
def __init__(self, tokenizer_path: Optional[Union[torch.Tensor, bytes, bytearray, memoryview, str, Path, Traversable]] = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data=None, tokenizer_args=None):
def __init__(self, tokenizer_path: Optional[Union[torch.Tensor, bytes, bytearray, memoryview, str, Path, Traversable]] = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data=None, tokenizer_args=None):
if tokenizer_data is None:
tokenizer_data = dict()
if tokenizer_args is None:
@ -568,6 +577,7 @@ class SDTokenizer:
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
self.end_token = None
self.min_padding = min_padding
self.pad_left = pad_left
empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
@ -635,6 +645,13 @@ class SDTokenizer:
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
return (embed, leftover)
def pad_tokens(self, tokens, amount):
if self.pad_left:
for i in range(amount):
tokens.insert(0, (self.pad_token, 1.0, 0))
else:
tokens.extend([(self.pad_token, 1.0, 0)] * amount)
def tokenize_with_weights(self, text: str, return_word_ids=False, tokenizer_options={}, **kwargs):
'''
Takes a prompt and converts it to a list of (token, weight, word id) elements.
@ -720,7 +737,7 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if self.pad_to_max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
self.pad_tokens(batch, remaining_length)
# start new batch
batch = []
if self.start_token is not None:
@ -734,11 +751,11 @@ class SDTokenizer:
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
if min_padding is not None:
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
self.pad_tokens(batch, min_padding)
if self.pad_to_max_length and len(batch) < self.max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
self.pad_tokens(batch, self.max_length - len(batch))
if min_length is not None and len(batch) < min_length:
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
self.pad_tokens(batch, min_length - len(batch))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens]
@ -756,7 +773,7 @@ SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer")
class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data: dict=None, clip_name="l", tokenizer=SDTokenizer, name=None):
def __init__(self, embedding_directory=None, tokenizer_data: dict = None, clip_name="l", tokenizer=SDTokenizer, name=None):
if tokenizer_data is None:
tokenizer_data = {}
if name is not None:
@ -792,11 +809,12 @@ class SD1Tokenizer:
def state_dict(self):
return getattr(self, self.clip).state_dict()
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config)
if model_options is None:
model_options = {}
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config)
class SD1ClipModel(torch.nn.Module):

View File

@ -27,6 +27,8 @@ from .text_encoders import sd3_clip
from .text_encoders import wan
from .text_encoders import qwen_image
from .text_encoders import hunyuan_image
from .text_encoders import kandinsky5
from .text_encoders import z_image
class SD15(supported_models_base.BASE):
@ -798,6 +800,40 @@ class FluxSchnell(Flux):
return out
class Flux2(Flux):
unet_config = {
"image_model": "flux2",
}
sampling_settings = {
"shift": 2.02,
}
unet_extra_config = {}
latent_format = latent_formats.Flux2
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux2(self, device=device)
return out
def clip_target(self, state_dict=None):
if state_dict is None:
state_dict = {}
return None # TODO
# pref = self.text_encoder_key_prefix[0]
# t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
# return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class GenmoMochi(supported_models_base.BASE):
unet_config = {
"image_model": "mochi_preview",
@ -1039,7 +1075,7 @@ class Lumina2(supported_models_base.BASE):
"shift": 6.0,
}
memory_usage_factor = 1.2
memory_usage_factor = 1.4
unet_extra_config = {}
latent_format = latent_formats.Flux
@ -1061,6 +1097,27 @@ class Lumina2(supported_models_base.BASE):
return supported_models_base.ClipTarget(lumina2.LuminaTokenizer, lumina2.te(**hunyuan_detect))
class ZImage(Lumina2):
unet_config = {
"image_model": "lumina2",
"dim": 3840,
}
sampling_settings = {
"multiplier": 1.0,
"shift": 3.0,
}
memory_usage_factor = 1.7
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(z_image.ZImageTokenizer, z_image.te(**hunyuan_detect))
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@ -1483,6 +1540,108 @@ class HunyuanImage21Refiner(HunyuanVideo):
return out
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
class HunyuanVideo15(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"vision_in_dim": 1152,
}
sampling_settings = {
"shift": 7.0,
}
memory_usage_factor = 4.0 # TODO
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
latent_format = latent_formats.HunyuanVideo15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo15(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(hunyuan_video.HunyuanVideo15Tokenizer, hunyuan_image.te(**hunyuan_detect))
class HunyuanVideo15_SR_Distilled(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"vision_in_dim": 1152,
"in_channels": 98,
}
sampling_settings = {
"shift": 2.0,
}
memory_usage_factor = 4.0 # TODO
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
latent_format = latent_formats.HunyuanVideo15
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanVideo15_SR_Distilled(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(hunyuan_video.HunyuanVideo15Tokenizer, hunyuan_image.te(**hunyuan_detect))
class Kandinsky5(supported_models_base.BASE):
unet_config = {
"image_model": "kandinsky5",
}
sampling_settings = {
"shift": 10.0,
}
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 1.1 # TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(kandinsky5.Kandinsky5Tokenizer, kandinsky5.te(**hunyuan_detect))
class Kandinsky5Image(Kandinsky5):
unet_config = {
"image_model": "kandinsky5",
"model_dim": 2560,
"visual_embed_dim": 64,
}
sampling_settings = {
"shift": 3.0,
}
latent_format = latent_formats.Flux
memory_usage_factor = 1.1 # TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5Image(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(kandinsky5.Kandinsky5TokenizerImage, kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid]

View File

@ -18,10 +18,10 @@
from typing import Optional
import torch
import logging
from . import model_base
from . import utils
from . import latent_formats
from .ops import Operations
class ClipTarget:
@ -30,6 +30,7 @@ class ClipTarget:
self.tokenizer = tokenizer
self.params = {}
class BASE:
unet_config = {}
unet_extra_config = {
@ -51,8 +52,8 @@ class BASE:
memory_usage_factor = 2.0
manual_cast_dtype: Optional[torch.dtype] = None
custom_operations: Optional[Operations] = None
scaled_fp8: Optional[torch.dtype] = None
custom_operations: Optional[torch.dtype] = None
quant_config = None # quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod
@ -120,3 +121,7 @@ class BASE:
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype
def __getattr__(self, name):
logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
return None

171
comfy/taesd/taehv.py Normal file
View File

@ -0,0 +1,171 @@
# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple, deque
import comfy.ops
operations=comfy.ops.disable_weight_init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
def conv(n_in, n_out, **kwargs):
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
class Clamp(nn.Module):
def forward(self, x):
return torch.tanh(x / 3) * 3
class MemBlock(nn.Module):
def __init__(self, n_in, n_out, act_func):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = act_func
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
class TPool(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
return self.conv(x.reshape(-1, self.stride * C, H, W))
class TGrow(nn.Module):
def __init__(self, n_f, stride):
super().__init__()
self.stride = stride
self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
def forward(self, x):
_NT, C, H, W = x.shape
x = self.conv(x)
return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
B, T, C, H, W = x.shape
if parallel:
x = x.reshape(B*T, C, H, W)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
BT, C, H, W = x.shape
T = BT // B
_x = x.reshape(B, T, C, H, W)
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
x = b(x, mem)
else:
x = b(x)
BT, C, H, W = x.shape
T = BT // B
x = x.view(B, T, C, H, W)
else:
out = []
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
progress_bar = tqdm(range(T), disable=not show_progress_bar)
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.popleft()
if i == 0:
progress_bar.update(1)
if i == len(model):
out.append(xt)
del xt
else:
b = model[i]
if isinstance(b, MemBlock):
if mem[i] is None:
xt_new = b(xt, xt * 0)
mem[i] = xt.detach().clone()
else:
xt_new = b(xt, mem[i])
mem[i] = xt.detach().clone()
del xt
work_queue.appendleft(TWorkItem(xt_new, i+1))
elif isinstance(b, TPool):
if mem[i] is None:
mem[i] = []
mem[i].append(xt.detach().clone())
if len(mem[i]) == b.stride:
B, C, H, W = xt.shape
xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
mem[i] = []
work_queue.appendleft(TWorkItem(xt, i+1))
elif isinstance(b, TGrow):
xt = b(xt)
NT, C, H, W = xt.shape
for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
work_queue.appendleft(TWorkItem(xt_next, i+1))
del xt
else:
xt = b(xt)
work_queue.appendleft(TWorkItem(xt, i+1))
progress_bar.close()
x = torch.stack(out, 1)
return x
class TAEHV(nn.Module):
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
super().__init__()
self.image_channels = 3
self.patch_size = 1
self.latent_channels = latent_channels
self.parallel = parallel
self.latent_format = latent_format
self.show_progress_bar = show_progress_bar
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
self.patch_size = 2
if self.latent_channels == 32: # HunyuanVideo1.5
act_func = nn.LeakyReLU(0.2, inplace=True)
else: # HunyuanVideo, Wan 2.1
act_func = nn.ReLU(inplace=True)
self.encoder = nn.Sequential(
conv(self.image_channels*self.patch_size**2, 64), act_func,
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
self.decoder = nn.Sequential(
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
)
@property
def show_progress_bar(self):
return self._show_progress_bar
@show_progress_bar.setter
def show_progress_bar(self, value):
self._show_progress_bar = value
def encode(self, x, **kwargs):
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if x.shape[1] % 4 != 0:
# pad at end to multiple of 4
n_pad = 4 - x.shape[1] % 4
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1)
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
return self.process_out(x)
def decode(self, x, **kwargs):
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
return x[:, self.frames_to_trim:].movedim(2, 1)

View File

@ -11,10 +11,10 @@ class T5XXLModel(sd1_clip.SDClipModel):
if model_options is None:
model_options = {}
textmodel_json_config = get_path_as_dict(textmodel_json_config, "t5_old_config_xxl.json", package=__package__)
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
if t5xxl_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
model_options["quantization_metadata"] = t5xxl_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
@ -43,14 +43,14 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
def te(dtype_t5=None, t5_quantization_metadata=None):
class CosmosTEModel_(CosmosT5XXL):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -1,12 +1,15 @@
import copy
import torch
from transformers import T5TokenizerFast
from transformers import T5TokenizerFast, LlamaTokenizerFast
from .sd3_clip import T5XXLModel
from .. import sd1_clip, model_management
from ..component_model import files
import json
import base64
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
@ -73,14 +76,127 @@ class FluxClipModel(torch.nn.Module):
return self.t5xxl.load_sd(sd)
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_
def load_mistral_tokenizer(data):
if torch.is_tensor(data):
data = data.numpy().tobytes()
# we just have to use the latest transformers
from transformers.integrations.mistral import MistralConverter
mistral_vocab = json.loads(data)
special_tokens = {}
vocab = {}
max_vocab = mistral_vocab["config"]["default_vocab_size"]
max_vocab -= len(mistral_vocab["special_tokens"])
for w in mistral_vocab["vocab"]:
r = w["rank"]
if r >= max_vocab:
continue
vocab[base64.b64decode(w["token_bytes"])] = r
for w in mistral_vocab["special_tokens"]:
if "token_bytes" in w:
special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"]
else:
special_tokens[w["token_str"]] = w["rank"]
all_special = []
for v in special_tokens:
all_special.append(v)
special_tokens.update(vocab)
vocab = special_tokens
return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False}
class MistralTokenizerClass:
@staticmethod
def from_pretrained(path, **kwargs):
return LlamaTokenizerFast(**kwargs)
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
self.tekken_data = tokenizer_data.get("tekken_model", None)
super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
def state_dict(self):
return {"tekken_model": self.tekken_data}
class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer)
self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]'
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Mistral3_24BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer=None, layer_idx=None, dtype=None, attention_mask=True, model_options={}):
if layer is None:
layer = [10, 20, 30]
textmodel_json_config = {}
num_layers = model_options.get("num_layers", None)
if num_layers is not None:
textmodel_json_config["num_hidden_layers"] = num_layers
if num_layers < 40:
textmodel_json_config["final_norm"] = False
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Flux2TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
out = out.movedim(1, 2)
out = out.reshape(out.shape[0], out.shape[1], -1)
return out, pooled, extra
def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
class Flux2TEModel_(Flux2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if pruned:
model_options = model_options.copy()
model_options["num_layers"] = 30
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Flux2TEModel_

View File

@ -33,14 +33,14 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
tokenizer_data = {}
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
class MochiTEModel_(MochiT5XXL):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -9,6 +9,7 @@ from ..model_management import intermediate_device, pick_weight_dtype
logger = logging.getLogger(__name__)
class HiDreamTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
@ -148,17 +149,17 @@ class HiDreamTEModel(torch.nn.Module):
return self.llama.load_sd(sd)
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
class HiDreamTEModel_(HiDreamTEModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HiDreamTEModel_

View File

@ -14,7 +14,7 @@ class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer_path = files.get_package_as_path("byt5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, embedding_directory=None, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
class HunyuanImageTokenizer(QwenImageTokenizer):
@ -52,10 +52,10 @@ class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options=None):
if model_options is None:
model_options = {}
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
if llama_scaled_fp8 is not None:
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@ -108,14 +108,14 @@ class HunyuanImageTEModel(QwenImageTEModel):
return super().load_sd(sd)
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
class QwenImageTEModel_(HunyuanImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)

View File

@ -2,10 +2,12 @@ import torch
import numbers
from transformers import LlamaTokenizerFast
from .hunyuan_image import HunyuanImageTokenizer
from .llama import Llama2
from .. import sd1_clip
from ..component_model import files
from ..model_management import pick_weight_dtype
from ..utils import detect_layer_quantization
def llama_detect(state_dict, prefix=""):
@ -14,9 +16,9 @@ def llama_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
quant = detect_layer_quantization(state_dict, prefix)
if quant is not None:
out["llama_quantization_metadata"] = quant
return out
@ -35,10 +37,10 @@ class LLAMAModel(sd1_clip.SDClipModel):
special_tokens = {"start": 128000, "pad": 128258}
if model_options is None:
model_options = {}
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
model_options["quantization_metadata"] = llama_quantization_metadata
textmodel_json_config = textmodel_json_config or {}
vocab_size = model_options.get("vocab_size", None)
@ -83,6 +85,15 @@ class HunyuanVideoTokenizer:
return {}
class HunyuanVideo15Tokenizer(HunyuanImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs)
class HunyuanVideoClipModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options=None):
super().__init__()
@ -161,14 +172,14 @@ class HunyuanVideoClipModel(torch.nn.Module):
return self.llama.load_sd(sd)
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
model_options["llama_quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HunyuanVideoClipModel_

View File

@ -0,0 +1,68 @@
from comfy import sd1_clip
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
from .llama import Qwen25_7BVLI
class Kandinsky5Tokenizer(QwenImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Kandinsky5TEModel(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
return cond, l_pooled, extra
def set_clip_options(self, options):
super().set_clip_options(options)
self.clip_l.set_clip_options(options)
def reset_clip_options(self):
super().reset_clip_options()
self.clip_l.reset_clip_options()
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
else:
return super().load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class Kandinsky5TEModel_(Kandinsky5TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Kandinsky5TEModel_

View File

@ -33,6 +33,30 @@ class Llama2Config:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Mistral3Small24BConfig:
vocab_size: int = 131072
hidden_size: int = 5120
intermediate_size: int = 32768
num_hidden_layers: int = 40
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
rope_theta: float = 1000000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
@ -55,6 +79,53 @@ class Qwen25_3BConfig:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
hidden_size: int = 2560
intermediate_size: int = 9728
num_hidden_layers: int = 36
num_attention_heads: int = 32
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass
class Ovis25_2BConfig:
vocab_size: int = 151936
hidden_size: int = 2048
intermediate_size: int = 6144
num_hidden_layers: int = 28
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass
@ -77,6 +148,7 @@ class Qwen25_7BVLI_Config:
q_norm = None
k_norm = None
rope_scale = None
final_norm: bool = True
@dataclass
@ -100,6 +172,7 @@ class Gemma2_2B_Config:
k_norm = None
sliding_attention = None
rope_scale = None
final_norm: bool = True
@dataclass
@ -123,6 +196,7 @@ class Gemma3_4B_Config:
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
final_norm: bool = True
class RMSNorm(nn.Module):
@ -375,7 +449,12 @@ class Llama2_(nn.Module):
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.final_norm:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
else:
self.norm = None
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
@ -411,8 +490,12 @@ class Llama2_(nn.Module):
intermediate = None
all_intermediate = None
only_layers = None
if intermediate_output is not None:
if intermediate_output == "all":
if isinstance(intermediate_output, list):
all_intermediate = []
only_layers = set(intermediate_output)
elif intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
@ -420,7 +503,8 @@ class Llama2_(nn.Module):
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if only_layers is None or (i in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
x = layer(
x=x,
attention_mask=mask,
@ -430,14 +514,17 @@ class Llama2_(nn.Module):
if i == intermediate_output:
intermediate = x.clone()
x = self.norm(x)
if self.norm is not None:
x = self.norm(x)
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if only_layers is None or ((i + 1) in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
if intermediate is not None and final_layer_norm_intermediate:
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
intermediate = self.norm(intermediate)
return x, intermediate
@ -466,6 +553,16 @@ class Llama2(BaseLlama, torch.nn.Module):
self.dtype = dtype
class Mistral3Small24B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Mistral3Small24BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_3B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@ -476,6 +573,26 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.dtype = dtype
class Qwen3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Ovis25_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Ovis25_2BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@ -8,29 +8,35 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_directory=None, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_directory=None, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
class NTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer)
class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None):
if model_options is None:
@ -38,10 +44,12 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
textmodel_json_config = textmodel_json_config or {}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None, name="gemma2_2b", clip_model=Gemma2_2BModel):
if model_options is None:
@ -49,21 +57,22 @@ class LuminaModel(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
model = None
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
return LuminaTEModel_

View File

@ -1,8 +1,7 @@
from transformers import Qwen2Tokenizer
from .. import sd1_clip
from .llama import Qwen25_3B
import os
from .llama import Qwen25_3B
from .. import sd1_clip
from ..component_model import files
@ -10,8 +9,8 @@ class Qwen25_3BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer_path = files.get_package_as_path("comfy.text_encoders.qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
tokenizer_path = files.get_package_as_path("comfy.text_encoders.qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_directory=embedding_directory, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
@ -21,20 +20,20 @@ class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer)
self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n'
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
class Qwen25_3BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = {}
textmodel_json_config = textmodel_json_config or {}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Omnigen2Model(sd1_clip.SD1ClipModel):
@ -44,15 +43,16 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
def te(dtype_llama=None, llama_quantization_metadata=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Omnigen2TEModel_

View File

@ -0,0 +1,66 @@
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
import torch
import numbers
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
class OvisTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Ovis25_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
class OvisTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs, template_end=-1):
out, pooled = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen3_2b"][0]
count_im_start = 0
if template_end == -1:
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 4004 and count_im_start < 1:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 1):
if tok_pairs[template_end + 1][0] == 25:
template_end += 1
out = out[:, template_end:]
return out, pooled, {}
def te(dtype_llama=None, llama_quantization_metadata=None):
class OvisTEModel_(OvisTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return OvisTEModel_

View File

@ -35,19 +35,20 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data=None):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -179,36 +179,36 @@
"special": false
},
"151665": {
"content": "<|img|>",
"content": "<tool_response>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
"special": false
},
"151666": {
"content": "<|endofimg|>",
"content": "</tool_response>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
"special": false
},
"151667": {
"content": "<|meta|>",
"content": "<think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
"special": false
},
"151668": {
"content": "<|endofmeta|>",
"content": "</think>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
"special": false
}
},
"additional_special_tokens": [

View File

@ -12,7 +12,7 @@ class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer_path = files.get_package_as_path("comfy.text_encoders.qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
@ -23,12 +23,14 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs):
skip_template = False
if text.startswith('<|im_start|>'):
skip_template = True
if text.startswith('<|start_header_id|>'):
skip_template = True
if prevent_empty_text and text == '':
text = ' '
if skip_template:
llama_text = text
@ -94,14 +96,14 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
return out, pooled, extra
def te(dtype_llama=None, llama_scaled_fp8=None):
def te(dtype_llama=None, llama_quantization_metadata=None):
class QwenImageTEModel_(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
model_options["quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -1,5 +1,6 @@
import copy
import logging
import comfy.utils
import torch
from transformers import T5TokenizerFast
@ -17,10 +18,10 @@ class T5XXLModel(sd1_clip.SDClipModel):
if model_options is None:
model_options = {}
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package=__package__)
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
if t5xxl_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
model_options["quantization_metadata"] = t5xxl_quantization_metadata
model_options = {**model_options, "model_name": "t5xxl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@ -32,9 +33,9 @@ def t5_xxl_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_t5"] = state_dict[t5_key].dtype
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
if quant is not None:
out["t5_quantization_metadata"] = quant
return out
@ -175,14 +176,14 @@ class SD3ClipModel(torch.nn.Module):
return self.t5xxl.load_sd(sd)
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_

View File

@ -37,14 +37,14 @@ class WanT5Model(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
def te(dtype_t5=None, t5_quantization_metadata=None):
class WanTEModel(WanT5Model):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
model_options["quantization_metadata"] = t5_quantization_metadata
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -0,0 +1,45 @@
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class ZImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Qwen3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ZImageTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
def te(dtype_llama=None, llama_quantization_metadata=None):
class ZImageTEModel_(ZImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return ZImageTEModel_

View File

@ -79,6 +79,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
except (ImportError, ModuleNotFoundError):
from numpy import generic as scalar
from numpy import dtype
try:
from numpy.dtypes import Float64DType # pylint: disable=no-name-in-module,import-error
except (ImportError, ModuleNotFoundError):
@ -779,6 +780,72 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
return key_map
def z_image_to_diffusers(mmdit_config, output_prefix=""):
n_layers = mmdit_config.get("n_layers", 0)
hidden_size = mmdit_config.get("dim", 0)
n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
key_map = {}
def add_block_keys(prefix_from, prefix_to, has_adaln=True):
for end in ("weight", "bias"):
k = "{}.attention.".format(prefix_from)
qkv = "{}.attention.qkv.{}".format(prefix_to, end)
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
block_map = {
"attention.norm_q.weight": "attention.q_norm.weight",
"attention.norm_k.weight": "attention.k_norm.weight",
"attention.to_out.0.weight": "attention.out.weight",
"attention.to_out.0.bias": "attention.out.bias",
"attention_norm1.weight": "attention_norm1.weight",
"attention_norm2.weight": "attention_norm2.weight",
"feed_forward.w1.weight": "feed_forward.w1.weight",
"feed_forward.w2.weight": "feed_forward.w2.weight",
"feed_forward.w3.weight": "feed_forward.w3.weight",
"ffn_norm1.weight": "ffn_norm1.weight",
"ffn_norm2.weight": "ffn_norm2.weight",
}
if has_adaln:
block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
for k, v in block_map.items():
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)
for i in range(n_layers):
add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))
for i in range(n_context_refiner):
add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
for i in range(n_noise_refiner):
add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
MAP_BASIC = [
("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
("x_embedder.weight", "all_x_embedder.2-1.weight"),
("x_embedder.bias", "all_x_embedder.2-1.bias"),
("x_pad_token", "x_pad_token"),
("cap_embedder.0.weight", "cap_embedder.0.weight"),
("cap_embedder.1.weight", "cap_embedder.1.weight"),
("cap_embedder.1.bias", "cap_embedder.1.bias"),
("cap_pad_token", "cap_pad_token"),
("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
]
for c, diffusers in MAP_BASIC:
key_map[diffusers] = "{}{}".format(output_prefix, c)
return key_map
def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size:
@ -1379,3 +1446,94 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
dim=1
)
return out
def pack_latents(latents):
latent_shapes = []
tensors = []
for tensor in latents:
latent_shapes.append(tensor.shape)
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
latent = torch.cat(tensors, dim=-1)
return latent, latent_shapes
def unpack_latents(combined_latent, latent_shapes):
if len(latent_shapes) > 1:
output_tensors = []
for shape in latent_shapes:
cut = math.prod(shape[1:])
tens = combined_latent[:, :, :cut]
combined_latent = combined_latent[:, :, cut:]
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
else:
output_tensors = combined_latent
return output_tensors
def detect_layer_quantization(state_dict, prefix):
for k in state_dict:
if k.startswith(prefix) and k.endswith(".comfy_quant"):
logging.info("Found quantization metadata version 1")
return {"mixed_ops": True}
return None
def convert_old_quants(state_dict, model_prefix="", metadata={}):
if metadata is None:
metadata = {}
quant_metadata = None
if "_quantization_metadata" not in metadata:
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict[scaled_fp8_key]
scaled_fp8_dtype = scaled_fp8_weight.dtype
if scaled_fp8_dtype == torch.float32:
scaled_fp8_dtype = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
full_precision_matrix_mult = True
else:
full_precision_matrix_mult = False
out_sd = {}
layers = {}
for k in list(state_dict.keys()):
if not k.startswith(model_prefix):
out_sd[k] = state_dict[k]
continue
k_out = k
w = state_dict.pop(k)
layer = None
if k_out.endswith(".scale_weight"):
layer = k_out[:-len(".scale_weight")]
k_out = "{}.weight_scale".format(layer)
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf
if k_out.endswith(".scale_input"):
layer = k_out[:-len(".scale_input")]
k_out = "{}.input_scale".format(layer)
if w.item() == 1.0:
continue
out_sd[k_out] = w
state_dict = out_sd
quant_metadata = {"layers": layers}
else:
quant_metadata = json.loads(metadata["_quantization_metadata"])
if quant_metadata is not None:
layers = quant_metadata["layers"]
for k, v in layers.items():
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
return state_dict, metadata

View File

@ -197,6 +197,7 @@ class LoRAAdapter(WeightAdapterBase):
lora_diff = torch.mm(
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
).reshape(weight.shape)
del mat1, mat2
if dora_scale is not None:
weight = weight_decompose(
dora_scale,

View File

@ -13,6 +13,7 @@ from comfy.cli_args import args
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
}

View File

@ -8,7 +8,7 @@ import os
import textwrap
import threading
from enum import Enum
from typing import Optional, Type, get_origin, get_args
from typing import Optional, Type, get_origin, get_args, get_type_hints
class TypeTracker:
@ -225,11 +225,18 @@ class AsyncToSyncConverter:
self._async_instance = async_class(*args, **kwargs)
# Handle annotated class attributes (like execution: Execution)
# Get all annotations from the class hierarchy
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
# Get all annotations from the class hierarchy and resolve string annotations
try:
# get_type_hints resolves string annotations to actual type objects
# This handles classes using 'from __future__ import annotations'
all_annotations = get_type_hints(async_class)
except Exception:
# Fallback to raw annotations if get_type_hints fails
# (e.g., for undefined forward references)
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
# For each annotated attribute, check if it needs to be created or wrapped
for attr_name, attr_type in all_annotations.items():
@ -630,15 +637,19 @@ class AsyncToSyncConverter:
"""Extract class attributes that are classes themselves."""
class_attributes = []
# Get resolved type hints to handle string annotations
try:
type_hints = get_type_hints(async_class)
except Exception:
type_hints = {}
# Look for class attributes that are classes
for name, attr in sorted(inspect.getmembers(async_class)):
if isinstance(attr, type) and not name.startswith("_"):
class_attributes.append((name, attr))
elif (
hasattr(async_class, "__annotations__")
and name in async_class.__annotations__
):
annotation = async_class.__annotations__[name]
elif name in type_hints:
# Use resolved type hint instead of raw annotation
annotation = type_hints[name]
if isinstance(annotation, type):
class_attributes.append((name, annotation))
@ -913,11 +924,15 @@ class AsyncToSyncConverter:
attribute_mappings = {}
# First check annotations for typed attributes (including from parent classes)
# Collect all annotations from the class hierarchy
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
# Resolve string annotations to actual types
try:
all_annotations = get_type_hints(async_class)
except Exception:
# Fallback to raw annotations
all_annotations = {}
for base_class in reversed(inspect.getmro(async_class)):
if hasattr(base_class, "__annotations__"):
all_annotations.update(base_class.__annotations__)
for attr_name, attr_type in sorted(all_annotations.items()):
for class_name, class_type in class_attributes:

View File

@ -5,11 +5,11 @@ from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
from . import _io as io
from . import _ui as ui
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io
from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
@ -80,7 +80,7 @@ class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources neeeded by the extension.
This should be used to initialize any global resources needed by the extension.
"""
@abstractmethod
@ -104,6 +104,8 @@ class Types:
VideoCodec = VideoCodec
VideoContainer = VideoContainer
VideoComponents = VideoComponents
MESH = MESH
VOXEL = VOXEL
ComfyAPI = ComfyAPI_latest

View File

@ -1,9 +1,10 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
from .._util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""
@ -72,6 +73,33 @@ class VideoInput(ABC):
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video.
Default implementation uses :meth:`get_components`, which may require
loading all frames into memory. File-based implementations should
override this method and use container/stream metadata instead.
Returns:
Total number of frames as an integer.
"""
return int(self.get_components().images.shape[0])
def get_frame_rate(self) -> Fraction:
"""
Returns the frame rate of the video.
Default implementation materializes the video into memory via
`get_components()`. Subclasses that can inspect the underlying
container (e.g. `VideoFromFile`) should override this with a more
efficient implementation.
Returns:
Frame rate as a Fraction.
"""
return self.get_components().frame_rate
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').

View File

@ -3,14 +3,14 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream # pylint: disable=no-name-in-module
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import AudioInput, VideoInput
from .._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import math
import torch
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
from .._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video without materializing them as
torch tensors.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
duration_seconds = float(video_stream.duration * video_stream.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
return frame_count
def get_frame_rate(self) -> Fraction:
"""
Returns the average frame rate of the video using container metadata
without decoding all frames.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
if video_stream.average_rate:
return Fraction(video_stream.average_rate)
# Fallback: estimate from frames + duration if available
if video_stream.frames and container.duration:
duration_seconds = float(container.duration / av.time_base)
if duration_seconds > 0:
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
# Last resort: match get_components_internal default
return Fraction(1)
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
@ -238,6 +303,13 @@ class VideoFromFile(VideoInput):
packet.stream = stream_map[packet.stream]
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream
class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.
@ -264,7 +336,10 @@ class VideoFromComponents(VideoInput):
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():

View File

@ -4,7 +4,8 @@ import copy
import inspect
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import asdict, dataclass
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
from typing_extensions import NotRequired, final
@ -25,8 +26,9 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_api.latest._resources import Resources, ResourcesLocal
from ._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
@ -149,6 +151,9 @@ class _IO_V3:
def __init__(self):
pass
def validate(self):
pass
@property
def io_type(self):
return self.Parent.io_type
@ -181,6 +186,9 @@ class Input(_IO_V3):
def get_io_type(self):
return _StringIOType(self.io_type)
def get_all(self) -> list[Input]:
return [self]
class WidgetInput(Input):
'''
Base class for a V3 Input with widget.
@ -560,6 +568,8 @@ class Conditioning(ComfyTypeIO):
'''Used by WAN Camera.'''
time_dim_concat: NotRequired[torch.Tensor]
'''Used by WAN Phantom Subject.'''
time_dim_replace: NotRequired[torch.Tensor]
'''Used by Kandinsky5 I2V.'''
CondList = list[tuple[torch.Tensor, PooledDict]]
Type = CondList
@ -628,6 +638,10 @@ class UpscaleModel(ComfyTypeIO):
if TYPE_CHECKING:
Type = ImageModelDescriptor
@comfytype(io_type="LATENT_UPSCALE_MODEL")
class LatentUpscaleModel(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO")
class Audio(ComfyTypeIO):
class AudioDict(TypedDict):
@ -656,11 +670,11 @@ class LossMap(ComfyTypeIO):
@comfytype(io_type="VOXEL")
class Voxel(ComfyTypeIO):
Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
Type = VOXEL
@comfytype(io_type="MESH")
class Mesh(ComfyTypeIO):
Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
Type = MESH
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):
@ -809,13 +823,61 @@ class MultiType:
else:
return super().as_dict()
@comfytype(io_type="COMFY_MATCHTYPE_V3")
class MatchType(ComfyTypeIO):
class Template:
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
self.template_id = template_id
# account for syntactic sugar
if not isinstance(allowed_types, Iterable):
allowed_types = [allowed_types]
for t in allowed_types:
if not isinstance(t, type):
if not isinstance(t, _ComfyType):
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
else:
if not issubclass(t, _ComfyType):
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
self.allowed_types = allowed_types
def as_dict(self):
return {
"template_id": self.template_id,
"allowed_types": ",".join([t.io_type for t in self.allowed_types]),
}
class Input(Input):
def __init__(self, id: str, template: MatchType.Template,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template = template
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
class Output(Output):
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
self.template = template
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
class DynamicInput(Input, ABC):
'''
Abstract class for dynamic input registration.
'''
@abstractmethod
def get_dynamic(self) -> list[Input]:
...
return []
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
pass
class DynamicOutput(Output, ABC):
'''
@ -825,99 +887,223 @@ class DynamicOutput(Output, ABC):
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
@abstractmethod
def get_dynamic(self) -> list[Output]:
...
return []
@comfytype(io_type="COMFY_AUTOGROW_V3")
class AutogrowDynamic(ComfyTypeI):
Type = list[Any]
class Input(DynamicInput):
def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template_input = template_input
if min is not None:
assert(min >= 1)
if max is not None:
assert(max >= 1)
class Autogrow(ComfyTypeI):
Type = dict[str, Any]
_MaxNames = 100 # NOTE: max 100 names for sanity
class _AutogrowTemplate:
def __init__(self, input: Input):
# dynamic inputs are not allowed as the template input
assert(not isinstance(input, DynamicInput))
self.input = copy.copy(input)
if isinstance(self.input, WidgetInput):
self.input.force_input = True
self.names: list[str] = []
self.cached_inputs = {}
def _create_input(self, input: Input, name: str):
new_input = copy.copy(self.input)
new_input.id = name
return new_input
def _create_cached_inputs(self):
for name in self.names:
self.cached_inputs[name] = self._create_input(self.input, name)
def get_all(self) -> list[Input]:
return list(self.cached_inputs.values())
def as_dict(self):
return prune_dict({
"input": create_input_dict_v1([self.input]),
})
def validate(self):
self.input.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
real_inputs = []
for name, input in self.cached_inputs.items():
if name in live_inputs:
real_inputs.append(input)
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
class TemplatePrefix(_AutogrowTemplate):
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
super().__init__(input)
self.prefix = prefix
assert(min >= 0)
assert(max >= 1)
assert(max <= Autogrow._MaxNames)
self.min = min
self.max = max
self.names = [f"{self.prefix}{i}" for i in range(self.max)]
self._create_cached_inputs()
def as_dict(self):
return super().as_dict() | prune_dict({
"prefix": self.prefix,
"min": self.min,
"max": self.max,
})
class TemplateNames(_AutogrowTemplate):
def __init__(self, input: Input, names: list[str], min: int=1):
super().__init__(input)
self.names = names[:Autogrow._MaxNames]
assert(min >= 0)
self.min = min
self._create_cached_inputs()
def as_dict(self):
return super().as_dict() | prune_dict({
"names": self.names,
"min": self.min,
})
class Input(DynamicInput):
def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template = template
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
def get_dynamic(self) -> list[Input]:
curr_count = 1
new_inputs = []
for i in range(self.min):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = self.optional or new_input.optional
if isinstance(self.template_input, WidgetInput):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
# pretend to expand up to max
for i in range(curr_count-1, self.max):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = True
if isinstance(self.template_input, WidgetInput):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
return new_inputs
return self.template.get_all()
@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
class ComboDynamic(ComfyTypeI):
class Input(DynamicInput):
def __init__(self, id: str):
pass
def get_all(self) -> list[Input]:
return [self] + self.template.get_all()
@comfytype(io_type="COMFY_MATCHTYPE_V3")
class MatchType(ComfyTypeIO):
class Template:
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
self.template_id = template_id
self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
def validate(self):
self.template.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
curr_prefix = f"{curr_prefix}{self.id}."
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend
for inner_dict in d.values():
if self.id in inner_dict:
del inner_dict[self.id]
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
class DynamicCombo(ComfyTypeI):
Type = dict[str, Any]
class Option:
def __init__(self, key: str, inputs: list[Input]):
self.key = key
self.inputs = inputs
def as_dict(self):
return {
"template_id": self.template_id,
"allowed_types": "".join(t.io_type for t in self.allowed_types),
"key": self.key,
"inputs": create_input_dict_v1(self.inputs),
}
class Input(DynamicInput):
def __init__(self, id: str, template: MatchType.Template,
def __init__(self, id: str, options: list[DynamicCombo.Option],
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template = template
self.options = options
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
# check if dynamic input's id is in live_inputs
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
key = live_inputs[self.id]
selected_option = None
for option in self.options:
if option.key == key:
selected_option = option
break
if selected_option is not None:
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
def get_dynamic(self) -> list[Input]:
return [self]
return [input for option in self.options for input in option.inputs]
def get_all(self) -> list[Input]:
return [self] + [input for option in self.options for input in option.inputs]
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
"options": [o.as_dict() for o in self.options],
})
class Output(DynamicOutput):
def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
self.template = template
def validate(self):
# make sure all nested inputs are validated
for option in self.options:
for input in option.inputs:
input.validate()
def get_dynamic(self) -> list[Output]:
return [self]
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
class DynamicSlot(ComfyTypeI):
Type = dict[str, Any]
class Input(DynamicInput):
def __init__(self, slot: Input, inputs: list[Input],
display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
assert(not isinstance(slot, DynamicInput))
self.slot = copy.copy(slot)
self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
optional = True
self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
self.inputs = inputs
self.force_input = None
# force widget inputs to have no widgets, otherwise this would be awkward
if isinstance(self.slot, WidgetInput):
self.force_input = True
self.slot.force_input = True
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
def get_dynamic(self) -> list[Input]:
return [self.slot] + self.inputs
def get_all(self) -> list[Input]:
return [self] + [self.slot] + self.inputs
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
"slotType": str(self.slot.get_io_type()),
"inputs": create_input_dict_v1(self.inputs),
"forceInput": self.force_input,
})
def validate(self):
self.slot.validate()
for input in self.inputs:
input.validate()
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
dynamic = d.setdefault("dynamic_paths", {})
if self is not None:
dynamic[self.id] = f"{curr_prefix}{self.id}"
for i in inputs:
if not isinstance(i, DynamicInput):
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
class V3Data(TypedDict):
hidden_inputs: dict[str, Any]
dynamic_paths: dict[str, Any]
class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any,
@ -979,6 +1165,7 @@ class NodeInfoV1:
output_is_list: list[bool]=None
output_name: list[str]=None
output_tooltips: list[str]=None
output_matchtypes: list[str]=None
name: str=None
display_name: str=None
description: str=None
@ -1015,9 +1202,9 @@ class Schema:
"""Display name of node."""
category: str = "sd"
"""The category of the node, as per the "Add Node" menu."""
inputs: list[Input]=None
outputs: list[Output]=None
hidden: list[Hidden]=None
inputs: list[Input] = field(default_factory=list)
outputs: list[Output] = field(default_factory=list)
hidden: list[Hidden] = field(default_factory=list)
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
is_input_list: bool = False
@ -1057,7 +1244,11 @@ class Schema:
'''Validate the schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other
'''
input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
nested_inputs: list[Input] = []
if self.inputs is not None:
for input in self.inputs:
nested_inputs.extend(input.get_all())
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
input_set = set(input_ids)
output_set = set(output_ids)
@ -1073,6 +1264,13 @@ class Schema:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0:
raise ValueError("\n".join(issues))
# validate inputs and outputs
if self.inputs is not None:
for input in self.inputs:
input.validate()
if self.outputs is not None:
for output in self.outputs:
output.validate()
def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
@ -1098,19 +1296,10 @@ class Schema:
if output.id is None:
output.id = f"_{i}_{output.io_type}_"
def get_v1_info(self, cls) -> NodeInfoV1:
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
# get V1 inputs
input = {
"required": {}
}
if self.inputs:
for i in self.inputs:
if isinstance(i, DynamicInput):
dynamic_inputs = i.get_dynamic()
for d in dynamic_inputs:
add_to_dict_v1(d, input)
else:
add_to_dict_v1(i, input)
input = create_input_dict_v1(self.inputs, live_inputs)
if self.hidden:
for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@ -1119,12 +1308,24 @@ class Schema:
output_is_list = []
output_name = []
output_tooltips = []
output_matchtypes = []
any_matchtypes = False
if self.outputs:
for o in self.outputs:
output.append(o.io_type)
output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None)
# special handling for MatchType
if isinstance(o, MatchType.Output):
output_matchtypes.append(o.template.template_id)
any_matchtypes = True
else:
output_matchtypes.append(None)
# clear out lists that are all None
if not any_matchtypes:
output_matchtypes = None
info = NodeInfoV1(
input=input,
@ -1133,6 +1334,7 @@ class Schema:
output_is_list=output_is_list,
output_name=output_name,
output_tooltips=output_tooltips,
output_matchtypes=output_matchtypes,
name=self.node_id,
display_name=self.display_name,
category=self.category,
@ -1178,16 +1380,57 @@ class Schema:
return info
def add_to_dict_v1(i: Input, input: dict):
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
input = {
"required": {}
}
add_to_input_dict_v1(input, inputs, live_inputs)
return input
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
for i in inputs:
if isinstance(i, DynamicInput):
add_to_dict_v1(i, d)
if live_inputs is not None:
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
else:
add_to_dict_v1(i, d)
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
key = "optional" if i.optional else "required"
as_dict = i.as_dict()
# for v1, we don't want to include the optional key
as_dict.pop("optional", None)
input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
if dynamic_dict is None:
value = (i.get_io_type(), as_dict)
else:
value = (i.get_io_type(), as_dict, dynamic_dict)
d.setdefault(key, {})[i.id] = value
def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict())
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
paths = v3_data.get("dynamic_paths", None)
if paths is None:
return values
values = values.copy()
result = {}
for key, path in paths.items():
parts = path.split(".")
current = result
for i, p in enumerate(parts):
is_last = (i == len(parts) - 1)
if is_last:
current[p] = values.pop(key, None)
else:
current = current.setdefault(p, {})
values.update(result)
return values
class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@ -1307,12 +1550,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
return type_clone
@final
@ -1429,14 +1672,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
schema = cls.FINALIZE_SCHEMA()
info = schema.get_v1_info(cls)
info = schema.get_v1_info(cls, live_inputs)
input = info.input
if not include_hidden:
input.pop("hidden", None)
if return_schema:
return input, schema
v3_data: V3Data = {}
dynamic = input.pop("dynamic_paths", None)
if dynamic is not None:
v3_data["dynamic_paths"] = dynamic
return input, schema, v3_data
return input
@final
@ -1509,7 +1756,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
raise NotImplementedError
@classmethod
def validate_inputs(cls, **kwargs) -> bool:
def validate_inputs(cls, **kwargs) -> bool | str:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
raise NotImplementedError
@ -1625,6 +1872,7 @@ __all__ = [
"StyleModel",
"Gligen",
"UpscaleModel",
"LatentUpscaleModel",
"Audio",
"Video",
"SVG",
@ -1648,6 +1896,10 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
# Dynamic Types
"MatchType",
# "DynamicCombo",
# "Autogrow",
# Other classes
"HiddenHolder",
"Hidden",
@ -1658,4 +1910,5 @@ __all__ = [
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
"V3Data",
]

View File

@ -0,0 +1 @@
from ._io import * # noqa: F403

View File

@ -4,6 +4,7 @@ import json
import logging
import os
import random
import uuid
from io import BytesIO
from typing import Type
@ -21,7 +22,7 @@ from PIL.PngImagePlugin import PngInfo
# used for image preview
from comfy.cli_args import args
from comfy.cmd import folder_paths
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
from ._io import ComfyNode, FolderType, Image, _UIOutput
logger = logging.getLogger(__name__)
@ -324,9 +325,10 @@ class AudioSaveHelper:
for key, value in metadata.items():
output_container.metadata[key] = value
layout = "mono" if waveform.shape[0] == 1 else "stereo"
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
@ -338,7 +340,7 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
@ -347,12 +349,12 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
layout="mono" if waveform.shape[0] == 1 else "stereo",
layout=layout,
)
frame.sample_rate = sample_rate
frame.pts = 0
@ -442,9 +444,19 @@ class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file
self.camera_info = camera_info
self.bg_image_path = None
bg_image = kwargs.get("bg_image", None)
if bg_image is not None:
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
img = PILImage.fromarray(img_array)
temp_dir = folder_paths.get_temp_directory()
filename = f"bg_{uuid.uuid4().hex}.png"
bg_image_path = os.path.join(temp_dir, filename)
img.save(bg_image_path, compress_level=1)
self.bg_image_path = f"temp/{filename}"
def as_dict(self):
return {"result": [self.model_file, self.camera_info]}
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
class PreviewText(_UIOutput):

View File

@ -0,0 +1 @@
from ._ui import * # noqa: F403

View File

@ -1,8 +1,11 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH
__all__ = [
# Utility Types
"VideoContainer",
"VideoCodec",
"VideoComponents",
"VOXEL",
"MESH",
]

View File

@ -0,0 +1,12 @@
import torch
class VOXEL:
def __init__(self, data: torch.Tensor):
self.data = data
class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
self.vertices = vertices
self.faces = faces

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from comfy_api.latest._input import ImageInput, AudioInput
from .._input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"

View File

@ -6,7 +6,7 @@ from comfy_api.latest import (
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@ -42,4 +42,8 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
"io",
"IO",
"ui",
"UI",
]

View File

@ -1,718 +0,0 @@
from __future__ import annotations
import aiohttp
import io
import logging
import mimetypes
import os
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
from comfy_api.input.video_types import VideoInput
from comfy_api.input.basic_types import AudioInput
from comfy_api_nodes.apis.client import (
ApiClient,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
UploadRequest,
UploadResponse,
)
from comfy.cmd.server import PromptServer
from comfy.cli_args import args
import numpy as np
from PIL import Image
import torch
import math
import base64
import uuid
from io import BytesIO
import av
async def download_url_to_video_output(
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.
Args:
video_url: The URL of the video to download.
Returns:
A Comfy node `VIDEO` output.
"""
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
raise ValueError(error_msg)
return VideoFromFile(video_io)
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
async def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
response: The response to validate and cast.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
ValueError: If the response is not valid.
"""
# validate raw JSON response
data = response.data
if not data or len(data) == 0:
raise ValueError("No images returned from API endpoint")
# Initialize list to store image tensors
image_tensors: list[torch.Tensor] = []
# Process each image in the data array
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
for img_data in data:
img_bytes: bytes
if img_data.b64_json:
img_bytes = base64.b64decode(img_data.b64_json)
elif img_data.url:
if node_id:
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
async with session.get(img_data.url) as resp:
if resp.status != 200:
raise ValueError("Failed to download generated image")
img_bytes = await resp.read()
else:
raise ValueError("Invalid image payload neither URL nor base64 data present.")
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))
return torch.stack(image_tensors, dim=0)
def validate_aspect_ratio(
aspect_ratio: str,
minimum_ratio: float,
maximum_ratio: float,
minimum_ratio_str: str,
maximum_ratio_str: str,
) -> float:
"""Validates and casts an aspect ratio string to a float.
Args:
aspect_ratio: The aspect ratio string to validate.
minimum_ratio: The minimum aspect ratio.
maximum_ratio: The maximum aspect ratio.
minimum_ratio_str: The minimum aspect ratio string.
maximum_ratio_str: The maximum aspect ratio string.
Returns:
The validated and cast aspect ratio.
Raises:
Exception: If the aspect ratio is not valid.
"""
# get ratio values
numbers = aspect_ratio.split(":")
if len(numbers) != 2:
raise TypeError(
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
)
try:
numerator = int(numbers[0])
denominator = int(numbers[1])
except ValueError as exc:
raise TypeError(
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
) from exc
calculated_ratio = numerator / denominator
# if not close to minimum and maximum, check bounds
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
calculated_ratio, maximum_ratio
):
if calculated_ratio < minimum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
return aspect_ratio
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()
async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
url: The URL to download.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
BytesIO object containing the downloaded content.
"""
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
"""Converts image data from BytesIO to a torch.Tensor.
Args:
image_bytesio: BytesIO object containing the image data.
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
PIL.UnidentifiedImageError: If the image data cannot be identified.
ValueError: If the specified mode is invalid.
"""
image = Image.open(image_bytesio)
image = image.convert(mode)
image_array = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(image_array).unsqueeze(0)
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
image_bytesio = await download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)
def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response_content))
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
if len(image.shape) > 3:
image = image[0]
# TODO: remove alpha if not allowed and present
input_tensor = image.cpu()
input_tensor = downscale_image_tensor(
input_tensor.unsqueeze(0), total_pixels=total_pixels
).squeeze()
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
return img
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
"""Converts a PIL Image to a BytesIO object."""
if not mime_type:
mime_type = "image/png"
img_byte_arr = io.BytesIO()
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
pil_format = mime_type.split("/")[-1].upper()
if pil_format == "JPG":
pil_format = "JPEG"
img.save(img_byte_arr, format=pil_format)
img_byte_arr.seek(0)
return img_byte_arr
def tensor_to_bytesio(
image: torch.Tensor,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
Args:
image: Input torch.Tensor image.
name: Optional filename for the BytesIO object.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data, with pointer set to the start of buffer.
"""
if not mime_type:
mime_type = "image/png"
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = (
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
)
return img_binary
def tensor_to_base64_string(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Base64 encoded string of the image.
"""
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_bytes = img_byte_arr.getvalue()
# Encode bytes to base64 string
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
return base64_encoded_string
def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Converts a tensor image to a Data URI string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Data URI string (e.g., 'data:image/png;base64,...').
"""
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
return f"data:{mime_type};base64,{base64_string}"
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"
async def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
Args:
file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file.
upload_mime_type: MIME type of the file.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded file.
"""
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response: UploadResponse = await operation.execute()
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
return response.download_url
def video_to_base64_string(
video: VideoInput,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.
Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = io.BytesIO()
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
async def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str, str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
Uses the specified container and codec for saving the video before upload.
Args:
video: VideoInput object (Comfy VIDEO type).
auth_kwargs: Optional authentication token(s).
container: The video container format to use (default: MP4).
codec: The video codec to use (default: H264).
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
Returns:
The download URL for the uploaded video file.
"""
if max_duration is not None:
try:
actual_duration = video.duration_seconds
if actual_duration is not None and actual_duration > max_duration:
raise ValueError(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"
# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = io.BytesIO()
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
"""
Prepares audio waveform for av library by converting to a contiguous numpy array.
Args:
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
Returns:
Contiguous numpy array of the audio waveform. If the audio was batched,
the first item is taken.
"""
if waveform.ndim != 3 or waveform.shape[0] != 1:
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
# If batch is > 1, take first item
if waveform.shape[0] > 1:
waveform = waveform[0]
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
if audio_data_np.dtype != np.float32:
audio_data_np = audio_data_np.astype(np.float32)
return audio_data_np
def audio_ndarray_to_bytesio(
audio_data_np: np.ndarray,
sample_rate: int,
container_format: str = "mp4",
codec_name: str = "aac",
) -> BytesIO:
"""
Encodes a numpy array of audio data into a BytesIO object.
"""
audio_bytes_io = io.BytesIO()
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
audio_data_np,
format="fltp",
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
)
frame.sample_rate = sample_rate
frame.pts = 0
for packet in audio_stream.encode(frame):
output_container.mux(packet)
# Flush stream
for packet in audio_stream.encode(None):
output_container.mux(packet)
audio_bytes_io.seek(0)
return audio_bytes_io
async def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str, str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str:
"""
Uploads a single audio input to ComfyUI API and returns its download URL.
Encodes the raw waveform into the specified format before uploading.
Args:
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded audio file.
"""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(io.BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
Args:
image: Input torch.Tensor image.
max_images: Maximum number of images to upload.
auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
download_urls.append(url)
return download_urls
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
if string is None:
raise Exception(f"Field '{field_name}' cannot be empty.")
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
if max_length and len(string) > max_length:
raise Exception(
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
)
def image_tensor_pair_to_batch(
image1: torch.Tensor, image2: torch.Tensor
) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)
def get_size(path_or_object: Union[str, io.BytesIO]) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())
def validate_container_format_is_mp4(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")

View File

@ -1,17 +0,0 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel
from . import PixverseDto
class ResponseData(BaseModel):
ErrCode: Optional[int] = None
ErrMsg: Optional[str] = None
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None

View File

@ -1,57 +0,0 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Field
class V2OpenAPII2VResp(BaseModel):
video_id: Optional[int] = Field(None, description='Video_id')
class V2OpenAPIT2VReq(BaseModel):
aspect_ratio: str = Field(
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
)
duration: int = Field(
...,
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
examples=[5],
)
model: str = Field(
..., description='Model version (only supports v3.5)', examples=['v3.5']
)
motion_mode: Optional[str] = Field(
'normal',
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
examples=['normal'],
)
negative_prompt: Optional[str] = Field(
None, description='Negative prompt\n', max_length=2048
)
prompt: str = Field(..., description='Prompt', max_length=2048)
quality: str = Field(
...,
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
examples=['540p'],
)
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
style: Optional[str] = Field(
None,
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
examples=['anime'],
)
template_id: Optional[int] = Field(
None,
description='Template ID (template_id must be activated before use)',
examples=[302325299692608],
)
water_mark: Optional[bool] = Field(
False,
description='Watermark (true: add watermark, false: no watermark)',
examples=[False],
)

View File

@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel):
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
class BFLFluxCannyImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxDepthImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
@ -108,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel):
# )
class Flux2ProGenerateRequest(BaseModel):
prompt: str = Field(...)
width: int = Field(1024, description="Must be a multiple of 32.")
height: int = Field(768, description="Must be a multiple of 32.")
seed: int | None = Field(None)
prompt_upsampling: bool | None = Field(None)
input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field(
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
@ -147,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel):
class BFLFluxProGenerateResponse(BaseModel):
id: str = Field(..., description='The unique identifier for the generation task.')
polling_url: str = Field(..., description='URL to poll for the generation result.')
id: str = Field(..., description="The unique identifier for the generation task.")
polling_url: str = Field(..., description="URL to poll for the generation result.")
cost: float | None = Field(None, description="Price in cents")
class BFLStatus(str, Enum):
@ -160,15 +146,8 @@ class BFLStatus(str, Enum):
error = "Error"
class BFLFluxProStatusResponse(BaseModel):
class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
result: Optional[Dict[str, Any]] = Field(
None, description="The result of the task (null if not completed)."
)
progress: confloat(ge=0.0, le=1.0) = Field(
..., description="The progress of the task (0.0 to 1.0)."
)
details: Optional[Dict[str, Any]] = Field(
None, description="Additional details about the task (null if not available)."
)
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)

Some files were not shown because too many files have changed in this diff Show More