mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
Merge branch 'master' into 20260427a_ltxv_addguide_causal_fix
This commit is contained in:
commit
a795eff4d5
@ -1,2 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-dynamic-vram
|
||||
pause
|
||||
31
.github/workflows/openapi-lint.yml
vendored
Normal file
31
.github/workflows/openapi-lint.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: OpenAPI Lint
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
paths:
|
||||
- 'openapi.yaml'
|
||||
- '.spectral.yaml'
|
||||
- '.github/workflows/openapi-lint.yml'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
spectral:
|
||||
name: Run Spectral
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '20'
|
||||
|
||||
- name: Install Spectral
|
||||
run: npm install -g @stoplight/spectral-cli@6
|
||||
|
||||
- name: Lint openapi.yaml
|
||||
run: spectral lint openapi.yaml --ruleset .spectral.yaml --fail-severity=error
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -23,3 +23,4 @@ web_custom_versions/
|
||||
.DS_Store
|
||||
filtered-openapi.yaml
|
||||
uv.lock
|
||||
.comfy_environment
|
||||
|
||||
91
.spectral.yaml
Normal file
91
.spectral.yaml
Normal file
@ -0,0 +1,91 @@
|
||||
extends:
|
||||
- spectral:oas
|
||||
|
||||
# Severity levels: error, warn, info, hint, off
|
||||
# Rules from the built-in "spectral:oas" ruleset are active by default.
|
||||
# Below we tune severity and add custom rules for our conventions.
|
||||
#
|
||||
# This ruleset mirrors Comfy-Org/cloud/.spectral.yaml so specs across the
|
||||
# organization are linted against a single consistent standard.
|
||||
|
||||
rules:
|
||||
# -----------------------------------------------------------------------
|
||||
# Built-in rule severity overrides
|
||||
# -----------------------------------------------------------------------
|
||||
operation-operationId: error
|
||||
operation-description: warn
|
||||
operation-tag-defined: error
|
||||
info-contact: off
|
||||
info-description: warn
|
||||
no-eval-in-markdown: error
|
||||
no-$ref-siblings: error
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: naming conventions
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Property names should be snake_case
|
||||
property-name-snake-case:
|
||||
description: Property names must be snake_case
|
||||
severity: warn
|
||||
given: "$.components.schemas.*.properties[*]~"
|
||||
then:
|
||||
function: pattern
|
||||
functionOptions:
|
||||
match: "^[a-z][a-z0-9]*(_[a-z0-9]+)*$"
|
||||
|
||||
# Operation IDs should be camelCase
|
||||
operation-id-camel-case:
|
||||
description: Operation IDs must be camelCase
|
||||
severity: warn
|
||||
given: "$.paths.*.*.operationId"
|
||||
then:
|
||||
function: pattern
|
||||
functionOptions:
|
||||
match: "^[a-z][a-zA-Z0-9]*$"
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: response conventions
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Error responses (4xx, 5xx) should use a consistent shape
|
||||
error-response-schema:
|
||||
description: Error responses should reference a standard error schema
|
||||
severity: hint
|
||||
given: "$.paths.*.*.responses[?(@property >= '400' && @property < '600')].content['application/json'].schema"
|
||||
then:
|
||||
field: "$ref"
|
||||
function: truthy
|
||||
|
||||
# All 2xx responses with JSON body should have a schema
|
||||
response-schema-defined:
|
||||
description: Success responses with JSON content should define a schema
|
||||
severity: warn
|
||||
given: "$.paths.*.*.responses[?(@property >= '200' && @property < '300')].content['application/json']"
|
||||
then:
|
||||
field: schema
|
||||
function: truthy
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Custom rules: best practices
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
# Path parameters must have a description
|
||||
path-param-description:
|
||||
description: Path parameters should have a description
|
||||
severity: warn
|
||||
given:
|
||||
- "$.paths.*.parameters[?(@.in == 'path')]"
|
||||
- "$.paths.*.*.parameters[?(@.in == 'path')]"
|
||||
then:
|
||||
field: description
|
||||
function: truthy
|
||||
|
||||
# Schemas should have a description
|
||||
schema-description:
|
||||
description: Component schemas should have a description
|
||||
severity: hint
|
||||
given: "$.components.schemas.*"
|
||||
then:
|
||||
field: description
|
||||
function: truthy
|
||||
@ -1,2 +1,2 @@
|
||||
# Admins
|
||||
* @comfyanonymous @kosinkadink @guill
|
||||
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai
|
||||
|
||||
23
README.md
23
README.md
@ -1,7 +1,7 @@
|
||||
<div align="center">
|
||||
|
||||
# ComfyUI
|
||||
**The most powerful and modular visual AI engine and application.**
|
||||
**The most powerful and modular AI engine for content creation.**
|
||||
|
||||
|
||||
[![Website][website-shield]][website-url]
|
||||
@ -31,10 +31,16 @@
|
||||
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
|
||||
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
|
||||
|
||||

|
||||
<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
|
||||
<br>
|
||||
</div>
|
||||
|
||||
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
|
||||
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
|
||||
- ComfyUI natively supports the latest open-source state of the art models.
|
||||
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
|
||||
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
|
||||
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
|
||||
- It integrates seamlessly into production pipelines with our API endpoints.
|
||||
|
||||
## Get Started
|
||||
|
||||
@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
||||
- Ernie Image
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
@ -126,7 +133,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
||||
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||
|
||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||
- Releases a new major stable version (e.g., v0.7.0) roughly every 2 weeks.
|
||||
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
||||
- Minor versions will be used for releases off the master branch.
|
||||
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
||||
@ -193,13 +200,15 @@ If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
#### All Official Portable Downloads:
|
||||
|
||||
[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
|
||||
[Portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
[Portable for Nvidia GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) (supports 20 series and above).
|
||||
|
||||
[Portable for Nvidia GPUs with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
|
||||
@ -28,8 +28,8 @@ def get_file_info(path: str, relative_to: str) -> FileInfo:
|
||||
return {
|
||||
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
|
||||
"size": os.path.getsize(path),
|
||||
"modified": os.path.getmtime(path),
|
||||
"created": os.path.getctime(path)
|
||||
"modified": int(os.path.getmtime(path) * 1000),
|
||||
"created": int(os.path.getctime(path) * 1000),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -90,8 +90,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
|
||||
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
|
||||
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
|
||||
|
||||
class LatentPreviewMethod(enum.Enum):
|
||||
NoPreviews = "none"
|
||||
@ -238,6 +238,8 @@ database_default_path = os.path.abspath(
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
||||
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
|
||||
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -63,7 +63,11 @@ class IndexListContextWindow(ContextWindowABC):
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
indices = self.index_list
|
||||
anchor_idx = getattr(self, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
indices = [anchor_idx] + list(indices)
|
||||
idx = tuple([slice(None)] * dim + [indices])
|
||||
window = full[idx]
|
||||
if retain_index_list:
|
||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||
@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
||||
|
||||
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||
if temporal_offset > 0:
|
||||
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
||||
anchor_idx = getattr(window, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
# anchor occupies one of the no-cond positions, so skip one fewer from window.index_list
|
||||
skip_count = temporal_offset - 1
|
||||
else:
|
||||
skip_count = temporal_offset
|
||||
|
||||
indices = [i - temporal_offset for i in window.index_list[skip_count:]]
|
||||
indices = [i for i in indices if 0 <= i]
|
||||
else:
|
||||
indices = list(window.index_list)
|
||||
@ -150,7 +161,8 @@ class ContextFuseMethod:
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
||||
causal_window_fix: bool=True):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -162,6 +174,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
self.causal_window_fix = causal_window_fix
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
@ -318,6 +331,14 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
|
||||
anchor_applied = False
|
||||
if self.causal_window_fix:
|
||||
anchor_idx = window.index_list[0] - 1
|
||||
if 0 <= anchor_idx < x_in.size(self.dim):
|
||||
window.causal_anchor_index = anchor_idx
|
||||
anchor_applied = True
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
@ -332,6 +353,12 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
|
||||
# strip causal_window_fix anchor if applied
|
||||
if anchor_applied:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
|
||||
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
34
comfy/deploy_environment.py
Normal file
34
comfy/deploy_environment.py
Normal file
@ -0,0 +1,34 @@
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_DEPLOY_ENV = "local-git"
|
||||
_ENV_FILENAME = ".comfy_environment"
|
||||
|
||||
# Resolve the ComfyUI install directory (the parent of this `comfy/` package).
|
||||
# We deliberately avoid `folder_paths.base_path` here because that is overridden
|
||||
# by the `--base-directory` CLI arg to a user-supplied path, whereas the
|
||||
# `.comfy_environment` marker is written by launchers/installers next to the
|
||||
# ComfyUI install itself.
|
||||
_COMFY_INSTALL_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_deploy_environment() -> str:
|
||||
env_file = os.path.join(_COMFY_INSTALL_DIR, _ENV_FILENAME)
|
||||
try:
|
||||
with open(env_file, encoding="utf-8") as f:
|
||||
# Cap the read so a malformed or maliciously crafted file (e.g.
|
||||
# a single huge line with no newline) can't blow up memory.
|
||||
first_line = f.readline(128).strip()
|
||||
value = "".join(c for c in first_line if 32 <= ord(c) < 127)
|
||||
if value:
|
||||
return value
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("Failed to read %s: %s", env_file, e)
|
||||
|
||||
return _DEFAULT_DEPLOY_ENV
|
||||
@ -1810,3 +1810,102 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||
num_frame_per_block=1):
|
||||
"""
|
||||
Autoregressive video sampler: block-by-block denoising with KV cache
|
||||
and flow-match re-noising for Causal Forcing / Self-Forcing models.
|
||||
|
||||
Requires a Causal-WAN compatible model (diffusion_model must expose
|
||||
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
|
||||
|
||||
All AR-loop parameters are passed via the SamplerARVideo node, not read
|
||||
from the checkpoint or transformer_options.
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
model_options = extra_args.get("model_options", {})
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
|
||||
if x.ndim != 5:
|
||||
raise ValueError(
|
||||
f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
|
||||
"This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
|
||||
)
|
||||
|
||||
inner_model = model.inner_model.inner_model
|
||||
causal_model = inner_model.diffusion_model
|
||||
|
||||
if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
|
||||
raise TypeError(
|
||||
"ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
|
||||
"exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
|
||||
"does not support this interface — choose a different sampler."
|
||||
)
|
||||
|
||||
seed = extra_args.get("seed", 0)
|
||||
|
||||
bs, c, lat_t, lat_h, lat_w = x.shape
|
||||
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
|
||||
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
|
||||
device = x.device
|
||||
model_dtype = inner_model.get_dtype()
|
||||
|
||||
kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
|
||||
crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
|
||||
|
||||
output = torch.zeros_like(x)
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
current_start_frame = 0
|
||||
num_sigma_steps = len(sigmas) - 1
|
||||
total_real_steps = num_blocks * num_sigma_steps
|
||||
step_count = 0
|
||||
|
||||
try:
|
||||
for block_idx in trange(num_blocks, disable=disable):
|
||||
bf = min(num_frame_per_block, lat_t - current_start_frame)
|
||||
fs, fe = current_start_frame, current_start_frame + bf
|
||||
noisy_input = x[:, :, fs:fe]
|
||||
|
||||
ar_state = {
|
||||
"start_frame": current_start_frame,
|
||||
"kv_caches": kv_caches,
|
||||
"crossattn_caches": crossattn_caches,
|
||||
}
|
||||
transformer_options["ar_state"] = ar_state
|
||||
|
||||
for i in range(num_sigma_steps):
|
||||
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
|
||||
|
||||
if callback is not None:
|
||||
scaled_i = step_count * num_sigma_steps // total_real_steps
|
||||
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
|
||||
"sigma_hat": sigmas[i], "denoised": denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
noisy_input = denoised
|
||||
else:
|
||||
sigma_next = sigmas[i + 1]
|
||||
torch.manual_seed(seed + block_idx * 1000 + i)
|
||||
fresh_noise = torch.randn_like(denoised)
|
||||
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
||||
|
||||
for cache in kv_caches:
|
||||
cache["end"] -= bf * frame_seq_len
|
||||
|
||||
step_count += 1
|
||||
|
||||
output[:, :, fs:fe] = noisy_input
|
||||
|
||||
for cache in kv_caches:
|
||||
cache["end"] -= bf * frame_seq_len
|
||||
zero_sigma = sigmas.new_zeros([1])
|
||||
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
|
||||
|
||||
current_start_frame += bf
|
||||
finally:
|
||||
transformer_options.pop("ar_state", None)
|
||||
|
||||
return output
|
||||
|
||||
@ -9,6 +9,7 @@ class LatentFormat:
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
spacial_downscale_ratio = 8
|
||||
temporal_downscale_ratio = 1
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent * self.scale_factor
|
||||
@ -224,6 +225,7 @@ class Flux2(LatentFormat):
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
|
||||
self.taesd_decoder_name = "taef2_decoder"
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
@ -234,6 +236,7 @@ class Flux2(LatentFormat):
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 6
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
@ -277,6 +280,7 @@ class LTXV(LatentFormat):
|
||||
latent_channels = 128
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 32
|
||||
temporal_downscale_ratio = 8
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors = [
|
||||
@ -420,6 +424,7 @@ class LTXAV(LTXV):
|
||||
class HunyuanVideo(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
scale_factor = 0.476986
|
||||
latent_rgb_factors = [
|
||||
[-0.0395, -0.0331, 0.0445],
|
||||
@ -446,6 +451,7 @@ class HunyuanVideo(LatentFormat):
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 8
|
||||
|
||||
latent_rgb_factors = [
|
||||
[ 0.1817, 0.2284, 0.2423],
|
||||
@ -471,6 +477,7 @@ class Cosmos1CV8x8x8(LatentFormat):
|
||||
class Wan21(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
|
||||
latent_rgb_factors = [
|
||||
[-0.1299, -0.1692, 0.2932],
|
||||
@ -733,6 +740,7 @@ class HunyuanVideo15(LatentFormat):
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
spacial_downscale_ratio = 16
|
||||
temporal_downscale_ratio = 4
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
@ -783,3 +791,29 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
No VAE encoding/decoding — the model operates directly on RGB pixels.
|
||||
"""
|
||||
pass
|
||||
|
||||
class CogVideoX(LatentFormat):
|
||||
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
|
||||
|
||||
scale_factor matches the vae/config.json scaling_factor for the 2b variant.
|
||||
The 5b-class checkpoints (CogVideoX-5b, CogVideoX-1.5-5B, CogVideoX-Fun-V1.5-*)
|
||||
use a different value; see CogVideoX1_5 below.
|
||||
"""
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
temporal_downscale_ratio = 4
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.15258426
|
||||
|
||||
|
||||
class CogVideoX1_5(CogVideoX):
|
||||
"""Latent format for 5b-class CogVideoX checkpoints.
|
||||
|
||||
Covers THUDM/CogVideoX-5b, THUDM/CogVideoX-1.5-5B, and the CogVideoX-Fun
|
||||
V1.5-5b family (including VOID inpainting). All of these have
|
||||
scaling_factor=0.7 in their vae/config.json. Auto-selected in
|
||||
supported_models.CogVideoX_T2V based on transformer hidden dim.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.7
|
||||
|
||||
0
comfy/ldm/cogvideo/__init__.py
Normal file
0
comfy/ldm/cogvideo/__init__.py
Normal file
573
comfy/ldm/cogvideo/model.py
Normal file
573
comfy/ldm/cogvideo/model.py
Normal file
@ -0,0 +1,573 @@
|
||||
# CogVideoX 3D Transformer - ported to ComfyUI native ops
|
||||
# Architecture reference: diffusers CogVideoXTransformer3DModel
|
||||
# Style reference: comfy/ldm/wan/model.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0):
|
||||
"""Returns (cos, sin) each with shape [seq_len, dim].
|
||||
|
||||
Frequencies are computed at dim//2 resolution then repeat_interleaved
|
||||
to full dim, matching CogVideoX's interleaved (real, imag) pair format.
|
||||
"""
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim))
|
||||
angles = torch.outer(pos.float(), freqs.float())
|
||||
cos = angles.cos().repeat_interleave(2, dim=-1).float()
|
||||
sin = angles.sin().repeat_interleave(2, dim=-1).float()
|
||||
return (cos, sin)
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cos_sin):
|
||||
"""Apply CogVideoX rotary embedding to query or key tensor.
|
||||
|
||||
x: [B, heads, seq_len, head_dim]
|
||||
freqs_cos_sin: (cos, sin) each [seq_len, head_dim//2]
|
||||
|
||||
Uses interleaved pair rotation (same as diffusers CogVideoX/Flux).
|
||||
head_dim is reshaped to (-1, 2) pairs, rotated, then flattened back.
|
||||
"""
|
||||
cos, sin = freqs_cos_sin
|
||||
cos = cos[None, None, :, :].to(x.device)
|
||||
sin = sin[None, None, :, :].to(x.device)
|
||||
|
||||
# Interleaved pairs: [B, H, S, D] -> [B, H, S, D//2, 2] -> (real, imag)
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half)
|
||||
args = timesteps[:, None].float() * freqs[None] * scale
|
||||
embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
|
||||
if flip_sin_to_cos:
|
||||
embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None):
|
||||
if isinstance(spatial_size, int):
|
||||
spatial_size = (spatial_size, spatial_size)
|
||||
|
||||
grid_w = torch.arange(spatial_size[0], dtype=torch.float32, device=device) / spatial_interpolation_scale
|
||||
grid_h = torch.arange(spatial_size[1], dtype=torch.float32, device=device) / spatial_interpolation_scale
|
||||
grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale
|
||||
|
||||
grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij")
|
||||
|
||||
embed_dim_spatial = 2 * (embed_dim // 3)
|
||||
embed_dim_temporal = embed_dim // 3
|
||||
|
||||
pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device)
|
||||
pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device)
|
||||
|
||||
T, H, W = grid_t.shape
|
||||
pos_embed_temporal = pos_embed_temporal.unsqueeze(1).unsqueeze(1).expand(-1, H, W, -1)
|
||||
pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1)
|
||||
|
||||
return pos_embed
|
||||
|
||||
|
||||
def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None):
|
||||
T, H, W = grid_h.shape
|
||||
half_dim = embed_dim // 2
|
||||
pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(T, H, W, half_dim)
|
||||
pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(T, H, W, half_dim)
|
||||
return torch.cat([pos_h, pos_w], dim=-1)
|
||||
|
||||
|
||||
def _get_1d_sincos_pos_embed(embed_dim, pos, device=None):
|
||||
half = embed_dim // 2
|
||||
freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half)
|
||||
args = pos.float().reshape(-1)[:, None] * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if embed_dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
|
||||
class CogVideoXPatchEmbed(nn.Module):
|
||||
def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920,
|
||||
text_dim=4096, bias=True, sample_width=90, sample_height=60,
|
||||
sample_frames=49, temporal_compression_ratio=4,
|
||||
max_text_seq_length=226, spatial_interpolation_scale=1.875,
|
||||
temporal_interpolation_scale=1.0, use_positional_embeddings=True,
|
||||
use_learned_positional_embeddings=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.dim = dim
|
||||
self.sample_height = sample_height
|
||||
self.sample_width = sample_width
|
||||
self.sample_frames = sample_frames
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
self.max_text_seq_length = max_text_seq_length
|
||||
self.spatial_interpolation_scale = spatial_interpolation_scale
|
||||
self.temporal_interpolation_scale = temporal_interpolation_scale
|
||||
self.use_positional_embeddings = use_positional_embeddings
|
||||
self.use_learned_positional_embeddings = use_learned_positional_embeddings
|
||||
|
||||
if patch_size_t is None:
|
||||
self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype)
|
||||
else:
|
||||
self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype)
|
||||
|
||||
self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype)
|
||||
|
||||
if use_positional_embeddings or use_learned_positional_embeddings:
|
||||
persistent = use_learned_positional_embeddings
|
||||
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
|
||||
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
|
||||
|
||||
def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None):
|
||||
post_patch_height = sample_height // self.patch_size
|
||||
post_patch_width = sample_width // self.patch_size
|
||||
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
|
||||
if self.patch_size_t is not None:
|
||||
post_time_compression_frames = post_time_compression_frames // self.patch_size_t
|
||||
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
||||
|
||||
pos_embedding = get_3d_sincos_pos_embed(
|
||||
self.dim,
|
||||
(post_patch_width, post_patch_height),
|
||||
post_time_compression_frames,
|
||||
self.spatial_interpolation_scale,
|
||||
self.temporal_interpolation_scale,
|
||||
device=device,
|
||||
)
|
||||
pos_embedding = pos_embedding.reshape(-1, self.dim)
|
||||
joint_pos_embedding = pos_embedding.new_zeros(
|
||||
1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False
|
||||
)
|
||||
joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
|
||||
return joint_pos_embedding
|
||||
|
||||
def forward(self, text_embeds, image_embeds):
|
||||
input_dtype = text_embeds.dtype
|
||||
text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype)
|
||||
batch_size, num_frames, channels, height, width = image_embeds.shape
|
||||
|
||||
proj_dtype = self.proj.weight.dtype
|
||||
if self.patch_size_t is None:
|
||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
||||
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
|
||||
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
|
||||
image_embeds = image_embeds.flatten(3).transpose(2, 3)
|
||||
image_embeds = image_embeds.flatten(1, 2)
|
||||
else:
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
|
||||
image_embeds = image_embeds.reshape(
|
||||
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
|
||||
)
|
||||
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
|
||||
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
|
||||
|
||||
embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
|
||||
|
||||
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
|
||||
text_seq_length = text_embeds.shape[1]
|
||||
num_image_patches = image_embeds.shape[1]
|
||||
|
||||
if self.use_learned_positional_embeddings:
|
||||
image_pos = self.pos_embedding[
|
||||
:, self.max_text_seq_length:self.max_text_seq_length + num_image_patches
|
||||
].to(device=embeds.device, dtype=embeds.dtype)
|
||||
else:
|
||||
image_pos = get_3d_sincos_pos_embed(
|
||||
self.dim,
|
||||
(width // self.patch_size, height // self.patch_size),
|
||||
num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
|
||||
self.spatial_interpolation_scale,
|
||||
self.temporal_interpolation_scale,
|
||||
device=embeds.device,
|
||||
).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype)
|
||||
|
||||
# Build joint: zeros for text + sincos for image
|
||||
joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype)
|
||||
joint_pos[:, text_seq_length:] = image_pos
|
||||
embeds = embeds + joint_pos
|
||||
|
||||
return embeds
|
||||
|
||||
|
||||
class CogVideoXLayerNormZero(nn.Module):
|
||||
def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, temb):
|
||||
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
||||
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
|
||||
|
||||
|
||||
class CogVideoXAdaLayerNorm(nn.Module):
|
||||
def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, temb):
|
||||
temb = self.linear(self.silu(temb))
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
class CogVideoXBlock(nn.Module):
|
||||
def __init__(self, dim, num_heads, head_dim, time_dim,
|
||||
eps=1e-5, ff_inner_dim=None, ff_bias=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
# Self-attention (joint text + latent)
|
||||
self.q = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
self.k = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
self.v = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
|
||||
self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
|
||||
self.attn_out = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
# Feed-forward (GELU approximate)
|
||||
inner_dim = ff_inner_dim or dim * 4
|
||||
self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype)
|
||||
self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options=None):
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# Norm & modulate
|
||||
norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb)
|
||||
|
||||
# Joint self-attention
|
||||
qkv_input = torch.cat([norm_encoder, norm_hidden], dim=1)
|
||||
b, s, _ = qkv_input.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(qkv_input).view(b, s, n, d)
|
||||
k = self.k(qkv_input).view(b, s, n, d)
|
||||
v = self.v(qkv_input)
|
||||
|
||||
q = self.norm_q(q).view(b, s, n, d)
|
||||
k = self.norm_k(k).view(b, s, n, d)
|
||||
|
||||
# Apply rotary embeddings to image tokens only (diffusers format: [B, heads, seq, head_dim])
|
||||
if image_rotary_emb is not None:
|
||||
q_img = q[:, text_seq_length:].transpose(1, 2) # [B, heads, img_seq, head_dim]
|
||||
k_img = k[:, text_seq_length:].transpose(1, 2)
|
||||
q_img = apply_rotary_emb(q_img, image_rotary_emb)
|
||||
k_img = apply_rotary_emb(k_img, image_rotary_emb)
|
||||
q = torch.cat([q[:, :text_seq_length], q_img.transpose(1, 2)], dim=1)
|
||||
k = torch.cat([k[:, :text_seq_length], k_img.transpose(1, 2)], dim=1)
|
||||
|
||||
attn_out = optimized_attention(
|
||||
q.reshape(b, s, n * d),
|
||||
k.reshape(b, s, n * d),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
attn_out = self.attn_out(attn_out)
|
||||
|
||||
attn_encoder, attn_hidden = attn_out.split([text_seq_length, s - text_seq_length], dim=1)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder
|
||||
|
||||
# Norm & modulate for FF
|
||||
norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb)
|
||||
|
||||
# Feed-forward (GELU on concatenated text + latent)
|
||||
ff_input = torch.cat([norm_encoder, norm_hidden], dim=1)
|
||||
ff_output = self.ff_out(F.gelu(self.ff_proj(ff_input), approximate="tanh"))
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(nn.Module):
|
||||
def __init__(self,
|
||||
num_attention_heads=30,
|
||||
attention_head_dim=64,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
flip_sin_to_cos=True,
|
||||
freq_shift=0,
|
||||
time_embed_dim=512,
|
||||
ofs_embed_dim=None,
|
||||
text_embed_dim=4096,
|
||||
num_layers=30,
|
||||
dropout=0.0,
|
||||
attention_bias=True,
|
||||
sample_width=90,
|
||||
sample_height=60,
|
||||
sample_frames=49,
|
||||
patch_size=2,
|
||||
patch_size_t=None,
|
||||
temporal_compression_ratio=4,
|
||||
max_text_seq_length=226,
|
||||
spatial_interpolation_scale=1.875,
|
||||
temporal_interpolation_scale=1.0,
|
||||
use_rotary_positional_embeddings=False,
|
||||
use_learned_positional_embeddings=False,
|
||||
patch_bias=True,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
dim = num_attention_heads * attention_head_dim
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.max_text_seq_length = max_text_seq_length
|
||||
self.use_rotary_positional_embeddings = use_rotary_positional_embeddings
|
||||
|
||||
# 1. Patch embedding
|
||||
self.patch_embed = CogVideoXPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
in_channels=in_channels,
|
||||
dim=dim,
|
||||
text_dim=text_embed_dim,
|
||||
bias=patch_bias,
|
||||
sample_width=sample_width,
|
||||
sample_height=sample_height,
|
||||
sample_frames=sample_frames,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
max_text_seq_length=max_text_seq_length,
|
||||
spatial_interpolation_scale=spatial_interpolation_scale,
|
||||
temporal_interpolation_scale=temporal_interpolation_scale,
|
||||
use_positional_embeddings=not use_rotary_positional_embeddings,
|
||||
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
||||
device=device, dtype=torch.float32, operations=operations,
|
||||
)
|
||||
|
||||
# 2. Time embedding
|
||||
self.time_proj_dim = dim
|
||||
self.time_proj_flip = flip_sin_to_cos
|
||||
self.time_proj_shift = freq_shift
|
||||
self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype)
|
||||
self.time_embedding_act = nn.SiLU()
|
||||
self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype)
|
||||
|
||||
# Optional OFS embedding (CogVideoX 1.5 I2V)
|
||||
self.ofs_proj_dim = ofs_embed_dim
|
||||
if ofs_embed_dim:
|
||||
self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
|
||||
self.ofs_embedding_act = nn.SiLU()
|
||||
self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
|
||||
else:
|
||||
self.ofs_embedding_linear_1 = None
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
CogVideoXBlock(
|
||||
dim=dim,
|
||||
num_heads=num_attention_heads,
|
||||
head_dim=attention_head_dim,
|
||||
time_dim=time_embed_dim,
|
||||
eps=1e-5,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype)
|
||||
|
||||
# 4. Output
|
||||
self.norm_out = CogVideoXAdaLayerNorm(
|
||||
time_dim=time_embed_dim, dim=dim, eps=1e-5,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
if patch_size_t is None:
|
||||
output_dim = patch_size * patch_size * out_channels
|
||||
else:
|
||||
output_dim = patch_size * patch_size * patch_size_t * out_channels
|
||||
|
||||
self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype)
|
||||
|
||||
self.spatial_interpolation_scale = spatial_interpolation_scale
|
||||
self.temporal_interpolation_scale = temporal_interpolation_scale
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
|
||||
def forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs):
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, ofs, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs):
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
# ComfyUI passes [B, C, T, H, W]
|
||||
batch_size, channels, t, h, w = x.shape
|
||||
|
||||
# Pad to patch size (temporal + spatial), same pattern as WAN
|
||||
p_t = self.patch_size_t if self.patch_size_t is not None else 1
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t, self.patch_size, self.patch_size))
|
||||
|
||||
# CogVideoX expects [B, T, C, H, W]
|
||||
x = x.permute(0, 2, 1, 3, 4)
|
||||
batch_size, num_frames, channels, height, width = x.shape
|
||||
|
||||
# Time embedding
|
||||
t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift)
|
||||
t_emb = t_emb.to(dtype=x.dtype)
|
||||
emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb)))
|
||||
|
||||
if self.ofs_embedding_linear_1 is not None and ofs is not None:
|
||||
ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift)
|
||||
ofs_emb = ofs_emb.to(dtype=x.dtype)
|
||||
ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb)))
|
||||
emb = emb + ofs_emb
|
||||
|
||||
# Patch embedding
|
||||
hidden_states = self.patch_embed(context, x)
|
||||
|
||||
text_seq_length = context.shape[1]
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# Rotary embeddings (if used)
|
||||
image_rotary_emb = None
|
||||
if self.use_rotary_positional_embeddings:
|
||||
post_patch_height = height // self.patch_size
|
||||
post_patch_width = width // self.patch_size
|
||||
if self.patch_size_t is None:
|
||||
post_time = num_frames
|
||||
else:
|
||||
post_time = num_frames // self.patch_size_t
|
||||
image_rotary_emb = self._get_rotary_emb(post_patch_height, post_patch_width, post_time, device=x.device)
|
||||
|
||||
# Transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
|
||||
# Output projection
|
||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# Unpatchify
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
|
||||
if p_t is None:
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
else:
|
||||
output = hidden_states.reshape(
|
||||
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
|
||||
)
|
||||
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
|
||||
|
||||
# Back to ComfyUI format [B, C, T, H, W] and crop padding
|
||||
output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w]
|
||||
return output
|
||||
|
||||
def _get_rotary_emb(self, h, w, t, device):
|
||||
"""Compute CogVideoX 3D rotary positional embeddings.
|
||||
|
||||
For CogVideoX 1.5 (patch_size_t != None): uses "slice" mode — grid positions
|
||||
are integer arange computed at max_size, then sliced to actual size.
|
||||
For CogVideoX 1.0 (patch_size_t == None): uses "linspace" mode with crop coords
|
||||
scaled by spatial_interpolation_scale.
|
||||
"""
|
||||
d = self.attention_head_dim
|
||||
dim_t = d // 4
|
||||
dim_h = d // 8 * 3
|
||||
dim_w = d // 8 * 3
|
||||
|
||||
if self.patch_size_t is not None:
|
||||
# CogVideoX 1.5: "slice" mode — positions are simple integer indices
|
||||
# Compute at max(sample_size, actual_size) then slice to actual
|
||||
base_h = self.patch_embed.sample_height // self.patch_size
|
||||
base_w = self.patch_embed.sample_width // self.patch_size
|
||||
max_h = max(base_h, h)
|
||||
max_w = max(base_w, w)
|
||||
|
||||
grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
|
||||
grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
|
||||
grid_t = torch.arange(t, device=device, dtype=torch.float32)
|
||||
else:
|
||||
# CogVideoX 1.0: "linspace" mode with interpolation scale
|
||||
grid_h = torch.linspace(0, h - 1, h, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
|
||||
grid_w = torch.linspace(0, w - 1, w, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
|
||||
grid_t = torch.arange(t, device=device, dtype=torch.float32)
|
||||
|
||||
freqs_t = _get_1d_rotary_pos_embed(dim_t, grid_t)
|
||||
freqs_h = _get_1d_rotary_pos_embed(dim_h, grid_h)
|
||||
freqs_w = _get_1d_rotary_pos_embed(dim_w, grid_w)
|
||||
|
||||
t_cos, t_sin = freqs_t
|
||||
h_cos, h_sin = freqs_h
|
||||
w_cos, w_sin = freqs_w
|
||||
|
||||
# Slice to actual size (for "slice" mode where grids may be larger)
|
||||
t_cos, t_sin = t_cos[:t], t_sin[:t]
|
||||
h_cos, h_sin = h_cos[:h], h_sin[:h]
|
||||
w_cos, w_sin = w_cos[:w], w_sin[:w]
|
||||
|
||||
# Broadcast and concatenate into [T*H*W, head_dim]
|
||||
t_cos = t_cos[:, None, None, :].expand(-1, h, w, -1)
|
||||
t_sin = t_sin[:, None, None, :].expand(-1, h, w, -1)
|
||||
h_cos = h_cos[None, :, None, :].expand(t, -1, w, -1)
|
||||
h_sin = h_sin[None, :, None, :].expand(t, -1, w, -1)
|
||||
w_cos = w_cos[None, None, :, :].expand(t, h, -1, -1)
|
||||
w_sin = w_sin[None, None, :, :].expand(t, h, -1, -1)
|
||||
|
||||
cos = torch.cat([t_cos, h_cos, w_cos], dim=-1).reshape(t * h * w, -1)
|
||||
sin = torch.cat([t_sin, h_sin, w_sin], dim=-1).reshape(t * h * w, -1)
|
||||
return (cos, sin)
|
||||
566
comfy/ldm/cogvideo/vae.py
Normal file
566
comfy/ldm/cogvideo/vae.py
Normal file
@ -0,0 +1,566 @@
|
||||
# CogVideoX VAE - ported to ComfyUI native ops
|
||||
# Architecture reference: diffusers AutoencoderKLCogVideoX
|
||||
# Style reference: comfy/ldm/wan/vae.py
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
"""Causal 3D convolution with temporal padding.
|
||||
|
||||
Uses comfy.ops.Conv3d with autopad='causal_zero' fast path: when input has
|
||||
a single temporal frame and no cache, the 3D conv weight is sliced to act
|
||||
as a 2D conv, avoiding computation on zero-padded temporal dimensions.
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size,) * 3
|
||||
|
||||
time_kernel, height_kernel, width_kernel = kernel_size
|
||||
self.time_kernel_size = time_kernel
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
height_pad = (height_kernel - 1) // 2
|
||||
width_pad = (width_kernel - 1) // 2
|
||||
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_kernel - 1, 0)
|
||||
|
||||
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
|
||||
dilation = (dilation, 1, 1)
|
||||
self.conv = ops.Conv3d(
|
||||
in_channels, out_channels, kernel_size,
|
||||
stride=stride, dilation=dilation,
|
||||
padding=(0, height_pad, width_pad),
|
||||
)
|
||||
|
||||
def forward(self, x, conv_cache=None):
|
||||
if self.pad_mode == "replicate":
|
||||
x = F.pad(x, self.time_causal_padding, mode="replicate")
|
||||
conv_cache = None
|
||||
else:
|
||||
kernel_t = self.time_kernel_size
|
||||
if kernel_t > 1:
|
||||
if conv_cache is None and x.shape[2] == 1:
|
||||
# Fast path: single frame, no cache. All temporal padding
|
||||
# frames are copies of the input (replicate-style), so the
|
||||
# 3D conv reduces to a 2D conv with summed temporal kernel.
|
||||
w = comfy.ops.cast_to_input(self.conv.weight, x)
|
||||
b = comfy.ops.cast_to_input(self.conv.bias, x) if self.conv.bias is not None else None
|
||||
w2d = w.sum(dim=2, keepdim=True)
|
||||
out = F.conv3d(x, w2d, b,
|
||||
self.conv.stride, self.conv.padding,
|
||||
self.conv.dilation, self.conv.groups)
|
||||
return out, None
|
||||
cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1)
|
||||
x = torch.cat(cached + [x], dim=2)
|
||||
conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None
|
||||
|
||||
out = self.conv(x)
|
||||
return out, conv_cache
|
||||
|
||||
|
||||
def _interpolate_zq(zq, target_size):
|
||||
"""Interpolate latent z to target (T, H, W), matching CogVideoX's first-frame-special handling."""
|
||||
t = target_size[0]
|
||||
if t > 1 and t % 2 == 1:
|
||||
z_first = F.interpolate(zq[:, :, :1], size=(1, target_size[1], target_size[2]))
|
||||
z_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, target_size[1], target_size[2]))
|
||||
return torch.cat([z_first, z_rest], dim=2)
|
||||
return F.interpolate(zq, size=target_size)
|
||||
|
||||
|
||||
class SpatialNorm3D(nn.Module):
|
||||
"""Spatially conditioned normalization."""
|
||||
def __init__(self, f_channels, zq_channels, groups=32):
|
||||
super().__init__()
|
||||
self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(self, f, zq, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
if zq.shape[-3:] != f.shape[-3:]:
|
||||
zq = _interpolate_zq(zq, f.shape[-3:])
|
||||
|
||||
conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
|
||||
conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
|
||||
|
||||
return self.norm_layer(f) * conv_y + conv_b, new_cache
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
"""3D ResNet block with optional spatial norm."""
|
||||
def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32,
|
||||
eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"):
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.spatial_norm_dim = spatial_norm_dim
|
||||
|
||||
if act_fn == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif act_fn == "swish":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
else:
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
if spatial_norm_dim is None:
|
||||
self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||
else:
|
||||
self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
|
||||
self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
|
||||
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = ops.Linear(temb_channels, out_channels)
|
||||
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = ops.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
residual = x
|
||||
|
||||
if zq is not None:
|
||||
x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1"))
|
||||
else:
|
||||
x = self.norm1(x)
|
||||
|
||||
x = self.nonlinearity(x)
|
||||
x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1"))
|
||||
|
||||
if temb is not None and hasattr(self, "temb_proj"):
|
||||
x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
if zq is not None:
|
||||
x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2"))
|
||||
else:
|
||||
x = self.norm2(x)
|
||||
|
||||
x = self.nonlinearity(x)
|
||||
x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2"))
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
residual = self.conv_shortcut(residual)
|
||||
|
||||
return x + residual, new_cache
|
||||
|
||||
|
||||
class Downsample3D(nn.Module):
|
||||
"""3D downsampling with optional temporal compression."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
|
||||
if t % 2 == 1:
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
if x_rest.shape[-1] > 0:
|
||||
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
else:
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.conv(x)
|
||||
x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample3D(nn.Module):
|
||||
"""3D upsampling with optional temporal decompression."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
if x.shape[2] > 1 and x.shape[2] % 2 == 1:
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
x_first = F.interpolate(x_first, scale_factor=2.0)
|
||||
x_rest = F.interpolate(x_rest, scale_factor=2.0)
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
|
||||
elif x.shape[2] > 1:
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
else:
|
||||
x = x.squeeze(2)
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
x = x[:, :, None, :, :]
|
||||
else:
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.conv(x)
|
||||
x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
class DownBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, add_downsample=True,
|
||||
compress_time=False, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
if self.downsamplers is not None:
|
||||
for ds in self.downsamplers:
|
||||
x = ds(x)
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class MidBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels, out_channels=in_channels,
|
||||
temb_channels=temb_channels, groups=groups, eps=eps,
|
||||
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class UpBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16,
|
||||
add_upsample=True, compress_time=False, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels, groups=groups, eps=eps,
|
||||
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
if self.upsamplers is not None:
|
||||
for us in self.upsamplers:
|
||||
x = us(x)
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class Encoder3D(nn.Module):
|
||||
def __init__(self, in_channels=3, out_channels=16,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
layers_per_block=3, act_fn="silu",
|
||||
eps=1e-6, groups=32, pad_mode="first",
|
||||
temporal_compression_ratio=4):
|
||||
super().__init__()
|
||||
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
self.down_blocks = nn.ModuleList()
|
||||
output_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final = i == len(block_out_channels) - 1
|
||||
compress_time = i < temporal_compress_level
|
||||
|
||||
self.down_blocks.append(DownBlock3D(
|
||||
in_channels=input_channel, out_channels=output_channel,
|
||||
temb_channels=0, num_layers=layers_per_block,
|
||||
eps=eps, act_fn=act_fn, groups=groups,
|
||||
add_downsample=not is_final, compress_time=compress_time,
|
||||
))
|
||||
|
||||
self.mid_block = MidBlock3D(
|
||||
in_channels=block_out_channels[-1], temb_channels=0,
|
||||
num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, x, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
for i, block in enumerate(self.down_blocks):
|
||||
key = f"down_block_{i}"
|
||||
x, new_cache[key] = block(x, None, None, conv_cache.get(key))
|
||||
|
||||
x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block"))
|
||||
|
||||
x = self.norm_out(x)
|
||||
x = self.conv_act(x)
|
||||
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class Decoder3D(nn.Module):
|
||||
def __init__(self, in_channels=16, out_channels=3,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
layers_per_block=3, act_fn="silu",
|
||||
eps=1e-6, groups=32, pad_mode="first",
|
||||
temporal_compression_ratio=4):
|
||||
super().__init__()
|
||||
reversed_channels = list(reversed(block_out_channels))
|
||||
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
self.mid_block = MidBlock3D(
|
||||
in_channels=reversed_channels[0], temb_channels=0,
|
||||
num_layers=2, eps=eps, act_fn=act_fn, groups=groups,
|
||||
spatial_norm_dim=in_channels, pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList()
|
||||
output_channel = reversed_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
prev_channel = output_channel
|
||||
output_channel = reversed_channels[i]
|
||||
is_final = i == len(block_out_channels) - 1
|
||||
compress_time = i < temporal_compress_level
|
||||
|
||||
self.up_blocks.append(UpBlock3D(
|
||||
in_channels=prev_channel, out_channels=output_channel,
|
||||
temb_channels=0, num_layers=layers_per_block + 1,
|
||||
eps=eps, act_fn=act_fn, groups=groups,
|
||||
spatial_norm_dim=in_channels,
|
||||
add_upsample=not is_final, compress_time=compress_time,
|
||||
))
|
||||
|
||||
self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, sample, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block"))
|
||||
|
||||
for i, block in enumerate(self.up_blocks):
|
||||
key = f"up_block_{i}"
|
||||
x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key))
|
||||
|
||||
x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out"))
|
||||
x = self.conv_act(x)
|
||||
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return x, new_cache
|
||||
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoX(nn.Module):
|
||||
"""CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper.
|
||||
|
||||
Uses rolling temporal decode: conv_in + mid_block + temporal up_blocks run
|
||||
on the full (low-res) tensor, then the expensive spatial-only up_blocks +
|
||||
norm_out + conv_out are processed in small temporal chunks with conv_cache
|
||||
carrying causal state between chunks. This keeps peak VRAM proportional to
|
||||
chunk_size rather than total frame count.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3, out_channels=3,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
latent_channels=16, layers_per_block=3,
|
||||
act_fn="silu", eps=1e-6, groups=32,
|
||||
temporal_compression_ratio=4,
|
||||
):
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
|
||||
self.encoder = Encoder3D(
|
||||
in_channels=in_channels, out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
|
||||
act_fn=act_fn, eps=eps, groups=groups,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
)
|
||||
self.decoder = Decoder3D(
|
||||
in_channels=latent_channels, out_channels=out_channels,
|
||||
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
|
||||
act_fn=act_fn, eps=eps, groups=groups,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
)
|
||||
|
||||
self.num_latent_frames_batch_size = 2
|
||||
self.num_sample_frames_batch_size = 8
|
||||
|
||||
def encode(self, x):
|
||||
t = x.shape[2]
|
||||
frame_batch = self.num_sample_frames_batch_size
|
||||
remainder = t % frame_batch
|
||||
conv_cache = None
|
||||
enc = []
|
||||
|
||||
# Process remainder frames first so only the first chunk can have an
|
||||
# odd temporal dimension — where Downsample3D's first-frame-special
|
||||
# handling in temporal compression is actually correct.
|
||||
if remainder > 0:
|
||||
chunk, conv_cache = self.encoder(x[:, :, :remainder], conv_cache=conv_cache)
|
||||
enc.append(chunk.to(x.device))
|
||||
|
||||
for start in range(remainder, t, frame_batch):
|
||||
chunk, conv_cache = self.encoder(x[:, :, start:start + frame_batch], conv_cache=conv_cache)
|
||||
enc.append(chunk.to(x.device))
|
||||
|
||||
enc = torch.cat(enc, dim=2)
|
||||
mean, _ = enc.chunk(2, dim=1)
|
||||
return mean
|
||||
|
||||
def decode(self, z):
|
||||
return self._decode_rolling(z)
|
||||
|
||||
def _decode_batched(self, z):
|
||||
"""Original batched decode - processes 2 latent frames through full decoder."""
|
||||
t = z.shape[2]
|
||||
frame_batch = self.num_latent_frames_batch_size
|
||||
num_batches = max(t // frame_batch, 1)
|
||||
conv_cache = None
|
||||
dec = []
|
||||
for i in range(num_batches):
|
||||
remaining = t % frame_batch
|
||||
start = frame_batch * i + (0 if i == 0 else remaining)
|
||||
end = frame_batch * (i + 1) + remaining
|
||||
chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache)
|
||||
dec.append(chunk.cpu())
|
||||
return torch.cat(dec, dim=2).to(z.device)
|
||||
|
||||
def _decode_rolling(self, z):
|
||||
"""Rolling decode - processes low-res layers on full tensor, then rolls
|
||||
through expensive high-res layers in temporal chunks."""
|
||||
decoder = self.decoder
|
||||
device = z.device
|
||||
|
||||
# Determine which up_blocks have temporal upsample vs spatial-only.
|
||||
# Temporal up_blocks are cheap (low res), spatial-only are expensive.
|
||||
temporal_compress_level = int(np.log2(self.temporal_compression_ratio))
|
||||
split_at = temporal_compress_level # first N up_blocks do temporal upsample
|
||||
|
||||
# Phase 1: conv_in + mid_block + temporal up_blocks on full tensor (low/medium res)
|
||||
x, _ = decoder.conv_in(z)
|
||||
x, _ = decoder.mid_block(x, None, z)
|
||||
|
||||
for i in range(split_at):
|
||||
x, _ = decoder.up_blocks[i](x, None, z)
|
||||
|
||||
# Phase 2: remaining spatial-only up_blocks + norm_out + conv_out in temporal chunks
|
||||
remaining_blocks = list(range(split_at, len(decoder.up_blocks)))
|
||||
chunk_size = 4 # pixel frames per chunk through high-res layers
|
||||
t_expanded = x.shape[2]
|
||||
|
||||
if t_expanded <= chunk_size or len(remaining_blocks) == 0:
|
||||
# Small enough to process in one go
|
||||
for i in remaining_blocks:
|
||||
x, _ = decoder.up_blocks[i](x, None, z)
|
||||
x, _ = decoder.norm_out(x, z)
|
||||
x = decoder.conv_act(x)
|
||||
x, _ = decoder.conv_out(x)
|
||||
return x
|
||||
|
||||
# Expand z temporally once to match Phase 2's time dimension.
|
||||
# z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB
|
||||
# for the old approach of pre-interpolating to every pixel resolution).
|
||||
z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4]))
|
||||
|
||||
# Process in temporal chunks, interpolating spatially per-chunk to avoid
|
||||
# allocating full [B, C, t_expanded, H, W] tensors at each resolution.
|
||||
dec_out = []
|
||||
conv_caches = {}
|
||||
|
||||
for chunk_start in range(0, t_expanded, chunk_size):
|
||||
chunk_end = min(chunk_start + chunk_size, t_expanded)
|
||||
x_chunk = x[:, :, chunk_start:chunk_end]
|
||||
z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end]
|
||||
z_spatial_cache = {}
|
||||
|
||||
for i in remaining_blocks:
|
||||
block = decoder.up_blocks[i]
|
||||
cache_key = f"up_block_{i}"
|
||||
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
|
||||
if hw_key not in z_spatial_cache:
|
||||
if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]:
|
||||
z_spatial_cache[hw_key] = z_t_chunk
|
||||
else:
|
||||
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
|
||||
x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key))
|
||||
conv_caches[cache_key] = new_cache
|
||||
|
||||
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
|
||||
if hw_key not in z_spatial_cache:
|
||||
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
|
||||
x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out"))
|
||||
conv_caches["norm_out"] = new_cache
|
||||
x_chunk = decoder.conv_act(x_chunk)
|
||||
x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
|
||||
conv_caches["conv_out"] = new_cache
|
||||
|
||||
dec_out.append(x_chunk.cpu())
|
||||
del z_spatial_cache
|
||||
|
||||
del x, z_time_expanded
|
||||
return torch.cat(dec_out, dim=2).to(device)
|
||||
@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_prefetch
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_prompt_timestep=a_prompt_timestep,
|
||||
)
|
||||
|
||||
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
|
||||
@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
|
||||
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
|
||||
n_rep = q.shape[-3] // k.shape[-3]
|
||||
k = k.repeat_interleave(n_rep, dim=-3)
|
||||
v = v.repeat_interleave(n_rep, dim=-3)
|
||||
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
h = heads
|
||||
if skip_reshape:
|
||||
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
b, _, dim_head = query.shape
|
||||
dim_head //= heads
|
||||
|
||||
if "scale" in kwargs:
|
||||
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
|
||||
query = query * (kwargs["scale"] * dim_head ** 0.5)
|
||||
|
||||
if skip_reshape:
|
||||
query = query.reshape(b * heads, -1, dim_head)
|
||||
value = value.reshape(b * heads, -1, dim_head)
|
||||
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
scale = kwargs.get("scale", dim_head ** -0.5)
|
||||
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
|
||||
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
|
||||
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
|
||||
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
k[i : i + SDP_BATCH_LIMIT],
|
||||
v[i : i + SDP_BATCH_LIMIT],
|
||||
attn_mask=m,
|
||||
dropout_p=0.0, is_causal=False
|
||||
dropout_p=0.0, is_causal=False, **sdpa_extra
|
||||
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
276
comfy/ldm/wan/ar_model.py
Normal file
276
comfy/ldm/wan/ar_model.py
Normal file
@ -0,0 +1,276 @@
|
||||
"""
|
||||
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
|
||||
autoregressive (frame-by-frame) video generation via Causal Forcing.
|
||||
|
||||
Weight-compatible with the standard WanModel -- same layer names, same shapes.
|
||||
The difference is purely in the forward pass: this model processes one temporal
|
||||
block at a time and maintains a KV cache across blocks.
|
||||
|
||||
Reference: https://github.com/thu-ml/Causal-Forcing
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.wan.model import (
|
||||
sinusoidal_embedding_1d,
|
||||
repeat_e,
|
||||
WanModel,
|
||||
WanAttentionBlock,
|
||||
)
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class CausalWanSelfAttention(nn.Module):
|
||||
"""Self-attention with KV cache support for autoregressive inference."""
|
||||
|
||||
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
|
||||
eps=1e-6, operation_settings={}):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
ops = operation_settings.get("operations")
|
||||
device = operation_settings.get("device")
|
||||
dtype = operation_settings.get("dtype")
|
||||
|
||||
self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, freqs, kv_cache=None, transformer_options={}):
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
|
||||
k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
|
||||
v = self.v(x).view(b, s, n, d)
|
||||
|
||||
if kv_cache is None:
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
k.view(b, s, n * d),
|
||||
v.view(b, s, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
end = kv_cache["end"]
|
||||
new_end = end + s
|
||||
|
||||
# Roped K and plain V go into cache
|
||||
kv_cache["k"][:, end:new_end] = k
|
||||
kv_cache["v"][:, end:new_end] = v
|
||||
kv_cache["end"] = new_end
|
||||
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
kv_cache["k"][:, :new_end].view(b, new_end, n * d),
|
||||
kv_cache["v"][:, :new_end].view(b, new_end, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanAttentionBlock(WanAttentionBlock):
|
||||
"""Transformer block with KV-cached self-attention and cross-attention caching."""
|
||||
|
||||
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
self.self_attn = CausalWanSelfAttention(
|
||||
dim, num_heads, window_size, qk_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, e, freqs, context, context_img_len=257,
|
||||
kv_cache=None, crossattn_cache=None, transformer_options={}):
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
|
||||
|
||||
# Self-attention with optional KV cache
|
||||
x = x.contiguous()
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, kv_cache=kv_cache, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
del y
|
||||
|
||||
# Cross-attention with optional caching
|
||||
if crossattn_cache is not None and crossattn_cache.get("is_init"):
|
||||
q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
|
||||
x_ca = optimized_attention(
|
||||
q, crossattn_cache["k"], crossattn_cache["v"],
|
||||
heads=self.num_heads, transformer_options=transformer_options)
|
||||
x = x + self.cross_attn.o(x_ca)
|
||||
else:
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
if crossattn_cache is not None:
|
||||
crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
|
||||
crossattn_cache["v"] = self.cross_attn.v(context)
|
||||
crossattn_cache["is_init"] = True
|
||||
|
||||
# FFN
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanModel(WanModel):
|
||||
"""
|
||||
Wan 2.1 diffusion backbone with causal KV-cache support.
|
||||
|
||||
Same weight structure as WanModel -- loads identical state dicts.
|
||||
Adds forward_block() for frame-by-frame autoregressive inference.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='t2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None):
|
||||
super().__init__(
|
||||
model_type=model_type, patch_size=patch_size, text_len=text_len,
|
||||
in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
|
||||
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
|
||||
num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
|
||||
cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
|
||||
wan_attn_block_class=CausalWanAttentionBlock,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward_block(self, x, timestep, context, start_frame,
|
||||
kv_caches, crossattn_caches, clip_fea=None):
|
||||
"""
|
||||
Forward one temporal block for autoregressive inference.
|
||||
|
||||
Args:
|
||||
x: [B, C, block_frames, H, W] input latent for the current block
|
||||
timestep: [B, block_frames] per-frame timesteps
|
||||
context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
|
||||
start_frame: temporal frame index for RoPE offset
|
||||
kv_caches: list of per-layer KV cache dicts
|
||||
crossattn_caches: list of per-layer cross-attention cache dicts
|
||||
clip_fea: optional CLIP features for I2V
|
||||
|
||||
Returns:
|
||||
flow_pred: [B, C_out, block_frames, H, W] flow prediction
|
||||
"""
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
bs, c, t, h, w = x.shape
|
||||
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# Per-frame time embedding
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
|
||||
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
# Text embedding (reuses crossattn_cache after first block)
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None and self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea)
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
# RoPE for current block's temporal position
|
||||
freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
|
||||
|
||||
# Transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
x = block(x, e=e0, freqs=freqs, context=context,
|
||||
context_img_len=context_img_len,
|
||||
kv_cache=kv_caches[i],
|
||||
crossattn_cache=crossattn_caches[i])
|
||||
|
||||
# Head
|
||||
x = self.head(x, e)
|
||||
|
||||
# Unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x[:, :, :t, :h, :w]
|
||||
|
||||
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
|
||||
"""Create fresh KV caches for all layers."""
|
||||
caches = []
|
||||
for _ in range(self.num_layers):
|
||||
caches.append({
|
||||
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"end": 0,
|
||||
})
|
||||
return caches
|
||||
|
||||
def init_crossattn_caches(self, batch_size, device, dtype):
|
||||
"""Create fresh cross-attention caches for all layers."""
|
||||
caches = []
|
||||
for _ in range(self.num_layers):
|
||||
caches.append({"is_init": False})
|
||||
return caches
|
||||
|
||||
def reset_kv_caches(self, kv_caches):
|
||||
"""Reset KV caches to empty (reuse allocated memory)."""
|
||||
for cache in kv_caches:
|
||||
cache["end"] = 0
|
||||
|
||||
def reset_crossattn_caches(self, crossattn_caches):
|
||||
"""Reset cross-attention caches."""
|
||||
for cache in crossattn_caches:
|
||||
cache["is_init"] = False
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.dim // self.num_heads
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
ar_state = transformer_options.get("ar_state")
|
||||
if ar_state is not None:
|
||||
bs = x.shape[0]
|
||||
block_frames = x.shape[2]
|
||||
t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
|
||||
return self.forward_block(
|
||||
x=x, timestep=t_per_frame, context=context,
|
||||
start_frame=ar_state["start_frame"],
|
||||
kv_caches=ar_state["kv_caches"],
|
||||
crossattn_caches=ar_state["crossattn_caches"],
|
||||
clip_fea=clip_fea,
|
||||
)
|
||||
|
||||
return super().forward(x, timestep, context, clip_fea=clip_fea,
|
||||
time_dim_concat=time_dim_concat,
|
||||
transformer_options=transformer_options, **kwargs)
|
||||
@ -17,6 +17,7 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_base
|
||||
@ -342,6 +343,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format
|
||||
|
||||
if isinstance(model, comfy.model_base.ErnieImage):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
@ -467,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
def prefetch_prepared_value(value, allocate_buffer, stream):
|
||||
if isinstance(value, torch.Tensor):
|
||||
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
|
||||
elif isinstance(value, weight_adapter.WeightAdapterBase):
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
|
||||
elif isinstance(value, list):
|
||||
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
@ -52,6 +53,7 @@ import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.cogvideo.model
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
import comfy.ldm.ernie.model
|
||||
import comfy.ldm.sam3.detector
|
||||
@ -81,6 +83,7 @@ class ModelType(Enum):
|
||||
IMG_TO_IMG = 9
|
||||
FLOW_COSMOS = 10
|
||||
IMG_TO_IMG_FLOW = 11
|
||||
V_PREDICTION_DDPM = 12
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@ -115,6 +118,8 @@ def model_sampling(model_config, model_type):
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||
elif model_type == ModelType.V_PREDICTION_DDPM:
|
||||
c = comfy.model_sampling.V_PREDICTION_DDPM
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -210,6 +215,11 @@ class BaseModel(torch.nn.Module):
|
||||
if "latent_shapes" in extra_conds:
|
||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["prefetch_dynamic_vbars"] = (
|
||||
self.current_patcher is not None and self.current_patcher.is_dynamic()
|
||||
)
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||
model_output, _ = utils.pack_latents(model_output)
|
||||
@ -1356,6 +1366,13 @@ class WAN21(BaseModel):
|
||||
return out
|
||||
|
||||
|
||||
class WAN21_CausalAR(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
|
||||
self.image_to_video = False
|
||||
|
||||
|
||||
class WAN21_Vace(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
|
||||
@ -1979,3 +1996,59 @@ class ErnieImage(BaseModel):
|
||||
class SAM3(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)
|
||||
|
||||
class CogVideoX(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
# Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent)
|
||||
extra_channels = self.diffusion_model.in_channels - noise.shape[1]
|
||||
if extra_channels == 0:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
shape = list(noise.shape)
|
||||
shape[1] = extra_channels
|
||||
return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
|
||||
latent_dim = self.latent_format.latent_channels
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
|
||||
if noise.ndim == 5 and image.ndim == 5:
|
||||
if image.shape[-3] < noise.shape[-3]:
|
||||
image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0)
|
||||
elif image.shape[-3] > noise.shape[-3]:
|
||||
image = image[:, :, :noise.shape[-3]]
|
||||
|
||||
for i in range(0, image.shape[1], latent_dim):
|
||||
image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
if image.shape[1] > extra_channels:
|
||||
image = image[:, :extra_channels]
|
||||
elif image.shape[1] < extra_channels:
|
||||
repeats = extra_channels // image.shape[1]
|
||||
remainder = extra_channels % image.shape[1]
|
||||
parts = [image] * repeats
|
||||
if remainder > 0:
|
||||
parts.append(image[:, :remainder])
|
||||
image = torch.cat(parts, dim=1)
|
||||
|
||||
return image
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
# OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR
|
||||
if self.diffusion_model.ofs_proj_dim is not None:
|
||||
ofs = kwargs.get("ofs", None)
|
||||
if ofs is None:
|
||||
noise = kwargs.get("noise", None)
|
||||
ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
|
||||
out['ofs'] = comfy.conds.CONDRegular(ofs)
|
||||
return out
|
||||
|
||||
@ -490,6 +490,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cogvideox"
|
||||
|
||||
# Extract config from weight shapes
|
||||
norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)]
|
||||
time_embed_dim = norm1_weight.shape[1]
|
||||
dim = norm1_weight.shape[0] // 6
|
||||
|
||||
dit_config["num_attention_heads"] = dim // 64
|
||||
dit_config["attention_head_dim"] = 64
|
||||
dit_config["time_embed_dim"] = time_embed_dim
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
|
||||
# Detect in_channels from patch_embed
|
||||
patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix)
|
||||
if patch_proj_key in state_dict_keys:
|
||||
w = state_dict[patch_proj_key]
|
||||
if w.ndim == 4:
|
||||
# Conv2d: [out, in, kh, kw] — CogVideoX 1.0
|
||||
dit_config["in_channels"] = w.shape[1]
|
||||
dit_config["patch_size"] = w.shape[2]
|
||||
elif w.ndim == 2:
|
||||
# Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["patch_size_t"] = 2
|
||||
dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32
|
||||
|
||||
text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix)
|
||||
if text_proj_key in state_dict_keys:
|
||||
dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1]
|
||||
|
||||
# Detect OFS embedding
|
||||
ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix)
|
||||
if ofs_key in state_dict_keys:
|
||||
dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1]
|
||||
|
||||
# Detect positional embedding type
|
||||
pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
|
||||
if pos_key in state_dict_keys:
|
||||
dit_config["use_learned_positional_embeddings"] = True
|
||||
dit_config["use_rotary_positional_embeddings"] = False
|
||||
else:
|
||||
dit_config["use_learned_positional_embeddings"] = False
|
||||
dit_config["use_rotary_positional_embeddings"] = True
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "wan2.1"
|
||||
|
||||
@ -31,6 +31,7 @@ from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
import comfy_aimdo.vram_buffer
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -112,10 +113,6 @@ if args.directml is not None:
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
_ = torch.xpu.device_count()
|
||||
@ -583,9 +580,6 @@ class LoadedModel:
|
||||
|
||||
real_model = self.model.model
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||
with torch.no_grad():
|
||||
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
|
||||
self.real_model = weakref.ref(real_model)
|
||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||
@ -727,13 +721,15 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||
|
||||
models_temp = set()
|
||||
# Order-preserving dedup. A plain set() would randomize iteration order across runs
|
||||
models_temp = {}
|
||||
for m in models:
|
||||
models_temp.add(m)
|
||||
models_temp[m] = None
|
||||
for mm in m.model_patches_models():
|
||||
models_temp.add(mm)
|
||||
models_temp[mm] = None
|
||||
|
||||
models = models_temp
|
||||
models = list(models_temp)
|
||||
models.reverse()
|
||||
|
||||
models_to_load = []
|
||||
|
||||
@ -1182,6 +1178,10 @@ stream_counters = {}
|
||||
|
||||
STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
|
||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||
|
||||
def get_cast_buffer(offload_stream, device, size, ref):
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
@ -1215,13 +1215,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
||||
|
||||
return cast_buffer
|
||||
|
||||
def get_aimdo_cast_buffer(offload_stream, device):
|
||||
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
|
||||
if cast_buffer is None:
|
||||
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
|
||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
|
||||
return cast_buffer
|
||||
def reset_cast_buffers():
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||
if offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
@ -1581,10 +1594,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.get_device_properties(device).has_fp16
|
||||
return torch.xpu.get_device_properties(device).has_fp16
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
@ -1650,10 +1660,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.is_bf16_supported()
|
||||
return torch.xpu.is_bf16_supported()
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
@ -1784,6 +1791,7 @@ def soft_empty_cache(force=False):
|
||||
if cpu_state == CPUState.MPS:
|
||||
torch.mps.empty_cache()
|
||||
elif is_intel_xpu():
|
||||
torch.xpu.synchronize()
|
||||
torch.xpu.empty_cache()
|
||||
elif is_ascend_npu():
|
||||
torch.npu.empty_cache()
|
||||
|
||||
@ -121,9 +121,20 @@ class LowVramPatch:
|
||||
self.patches = patches
|
||||
self.convert_func = convert_func # TODO: remove
|
||||
self.set_func = set_func
|
||||
self.prepared_patches = None
|
||||
|
||||
def prepare(self, allocate_buffer, stream):
|
||||
self.prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
|
||||
for patch in self.patches[self.key]
|
||||
]
|
||||
|
||||
def clear_prepared(self):
|
||||
self.prepared_patches = None
|
||||
|
||||
def __call__(self, weight):
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
|
||||
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
|
||||
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
||||
|
||||
|
||||
66
comfy/model_prefetch.py
Normal file
66
comfy/model_prefetch.py
Normal file
@ -0,0 +1,66 @@
|
||||
import comfy_aimdo.model_vbar
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
|
||||
PREFETCH_QUEUES = []
|
||||
|
||||
def cleanup_prefetched_modules(comfy_modules):
|
||||
for s in comfy_modules:
|
||||
prefetch = getattr(s, "_prefetch", None)
|
||||
if prefetch is None:
|
||||
continue
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_fn.clear_prepared()
|
||||
if prefetch["signature"] is not None:
|
||||
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||
delattr(s, "_prefetch")
|
||||
|
||||
def cleanup_prefetch_queues():
|
||||
global PREFETCH_QUEUES
|
||||
|
||||
for queue in PREFETCH_QUEUES:
|
||||
for entry in queue:
|
||||
if entry is None or not isinstance(entry, tuple):
|
||||
continue
|
||||
_, prefetch_state = entry
|
||||
comfy_modules = prefetch_state[1]
|
||||
if comfy_modules is not None:
|
||||
cleanup_prefetched_modules(comfy_modules)
|
||||
PREFETCH_QUEUES = []
|
||||
|
||||
def prefetch_queue_pop(queue, device, module):
|
||||
if queue is None:
|
||||
return
|
||||
|
||||
consumed = queue.pop(0)
|
||||
if consumed is not None:
|
||||
offload_stream, prefetch_state = consumed
|
||||
if offload_stream is not None:
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||
_, comfy_modules = prefetch_state
|
||||
if comfy_modules is not None:
|
||||
cleanup_prefetched_modules(comfy_modules)
|
||||
|
||||
prefetch = queue[0]
|
||||
if prefetch is not None:
|
||||
comfy_modules = []
|
||||
for s in prefetch.modules():
|
||||
if hasattr(s, "_v"):
|
||||
comfy_modules.append(s)
|
||||
|
||||
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
queue[0] = (offload_stream, (prefetch, comfy_modules))
|
||||
|
||||
def make_prefetch_queue(queue, device, transformer_options):
|
||||
if (not transformer_options.get("prefetch_dynamic_vbars", False)
|
||||
or comfy.model_management.NUM_STREAMS == 0
|
||||
or comfy.model_management.is_device_cpu(device)
|
||||
or not comfy.model_management.device_supports_non_blocking(device)):
|
||||
return None
|
||||
|
||||
queue = [None] + queue + [None]
|
||||
PREFETCH_QUEUES.append(queue)
|
||||
return queue
|
||||
@ -54,6 +54,30 @@ class V_PREDICTION(EPS):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
class V_PREDICTION_DDPM:
|
||||
"""CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v.
|
||||
x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v
|
||||
= x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1)
|
||||
"""
|
||||
def calculate_input(self, sigma, noise):
|
||||
return noise
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
|
||||
else:
|
||||
noise = noise * sigma
|
||||
noise += latent_image
|
||||
return noise
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
return latent
|
||||
|
||||
class EDM(V_PREDICTION):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
|
||||
269
comfy/ops.py
269
comfy/ops.py
@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
|
||||
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
||||
return weight, bias, (None, None, None)
|
||||
|
||||
# FIXME: add n=1 cache hit fast path
|
||||
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
cast_buffer = None
|
||||
cast_buffer_offset = 0
|
||||
|
||||
def ensure_offload_stream(module, required_size, check_largest):
|
||||
nonlocal offload_stream
|
||||
nonlocal cast_buffer
|
||||
|
||||
if offload_stream is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
|
||||
return
|
||||
|
||||
current_size = 0 if cast_buffer is None else cast_buffer.size()
|
||||
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
cast_buffer = None
|
||||
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
|
||||
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
|
||||
|
||||
def get_cast_buffer(buffer_size):
|
||||
nonlocal offload_stream
|
||||
nonlocal cast_buffer
|
||||
nonlocal cast_buffer_offset
|
||||
|
||||
if buffer_size == 0:
|
||||
return None
|
||||
|
||||
if offload_stream is None:
|
||||
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
|
||||
|
||||
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
|
||||
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
|
||||
cast_buffer_offset += buffer_size
|
||||
return buffer
|
||||
|
||||
for s in comfy_modules:
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
prefetch = {
|
||||
"signature": signature,
|
||||
"resident": resident,
|
||||
}
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
if signature is not None:
|
||||
if resident:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
s._prefetch = prefetch
|
||||
continue
|
||||
|
||||
if not resident:
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
needs_cast = False
|
||||
|
||||
xfer_source = [ s.weight, s.bias ]
|
||||
|
||||
@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
if data is None:
|
||||
continue
|
||||
if data.dtype != geometry.dtype:
|
||||
needs_cast = True
|
||||
cast_dest = xfer_dest
|
||||
if cast_dest is None:
|
||||
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
|
||||
xfer_dest = None
|
||||
break
|
||||
|
||||
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if xfer_dest is None and offload_stream is not None:
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
if xfer_dest is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
|
||||
if xfer_dest is None:
|
||||
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
||||
offload_stream = None
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
if signature is None and pin is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
if cast_dest is not None:
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
prefetch["cast_geometry"] = cast_geometry
|
||||
prefetch["needs_cast"] = needs_cast
|
||||
s._prefetch = prefetch
|
||||
|
||||
return offload_stream
|
||||
|
||||
|
||||
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
|
||||
|
||||
prefetch = getattr(s, "_prefetch", None)
|
||||
|
||||
if prefetch["resident"]:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = prefetch["xfer_dest"]
|
||||
if prefetch["needs_cast"]:
|
||||
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
|
||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
||||
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
|
||||
if post_cast is not None:
|
||||
post_cast.copy_(pre_cast)
|
||||
xfer_dest = cast_dest
|
||||
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
if signature is not None:
|
||||
if prefetch["signature"] is not None:
|
||||
s._v_weight = weight
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
s._v_signature = prefetch["signature"]
|
||||
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
fns = getattr(s, param_key + "_function", [])
|
||||
|
||||
if x is None:
|
||||
return None
|
||||
|
||||
orig = x
|
||||
|
||||
def to_dequant(tensor, dtype):
|
||||
@ -205,14 +248,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
update_weight = signature is not None
|
||||
update_weight = prefetch["signature"] is not None
|
||||
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
|
||||
if bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
|
||||
|
||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||
if s.bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||
if prefetch["signature"] is not None:
|
||||
prefetch["resident"] = True
|
||||
|
||||
#FIXME: weird offload return protocol
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
return weight, bias
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||
@ -230,10 +274,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
def format_return(result, offloadable):
|
||||
weight, bias, offload_stream = result
|
||||
return (weight, bias, offload_stream) if offloadable else (weight, bias)
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
|
||||
prefetched = hasattr(s, "_prefetch")
|
||||
offload_stream = None
|
||||
offload_device = None
|
||||
if not prefetched:
|
||||
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
|
||||
|
||||
if not prefetched:
|
||||
if getattr(s, "_prefetch")["signature"] is not None:
|
||||
offload_device = device
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_fn.clear_prepared()
|
||||
delattr(s, "_prefetch")
|
||||
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
|
||||
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
@ -280,11 +360,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
if offloadable:
|
||||
return weight, bias, (offload_stream, weight_a, bias_a)
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
|
||||
|
||||
|
||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
@ -1173,6 +1249,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
class Embedding(manual_cast.Embedding):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
weight_key = f"{prefix}weight"
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
# Only fp8 makes sense for embeddings (per-row dequant via index select).
|
||||
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
|
||||
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
|
||||
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
|
||||
self.quant_format = quant_format
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
|
||||
weight = state_dict.pop(weight_key)
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(scale_key, None)
|
||||
if scale is not None:
|
||||
scale = scale.float()
|
||||
manually_loaded_keys.append(scale_key)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.num_embeddings, self.embedding_dim),
|
||||
)
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
|
||||
requires_grad=False)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
else:
|
||||
if layer_conf is not None:
|
||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight') or self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
|
||||
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
||||
weight = self.weight
|
||||
|
||||
# Optimized path: lookup in fp8, dequantize only the selected rows.
|
||||
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
|
||||
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
|
||||
if isinstance(qdata, QuantizedTensor):
|
||||
scale = qdata._params.scale
|
||||
qdata = qdata._qdata
|
||||
else:
|
||||
scale = None
|
||||
|
||||
x = torch.nn.functional.embedding(
|
||||
input, qdata, self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.sparse)
|
||||
uncast_bias_weight(self, qdata, None, offload_stream)
|
||||
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
|
||||
x = x.to(dtype=target_dtype)
|
||||
if scale is not None and scale != 1.0:
|
||||
x = x * scale.to(dtype=target_dtype)
|
||||
return x
|
||||
|
||||
# Fallback for non-quantized or weight_function (LoRA) case
|
||||
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
|
||||
|
||||
return MixedPrecisionOps
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
from comfy_kitchen.tensor import (
|
||||
@ -21,7 +23,15 @@ try:
|
||||
ck.registry.disable("cuda")
|
||||
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
|
||||
|
||||
ck.registry.disable("triton")
|
||||
if args.enable_triton_backend:
|
||||
try:
|
||||
import triton
|
||||
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
|
||||
except ImportError as e:
|
||||
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
|
||||
ck.registry.disable("triton")
|
||||
else:
|
||||
ck.registry.disable("triton")
|
||||
for k, v in ck.list_backends().items():
|
||||
logging.info(f"Found comfy_kitchen backend {k}: {v}")
|
||||
except ImportError as e:
|
||||
|
||||
@ -3,6 +3,7 @@ import comfy.model_management
|
||||
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
|
||||
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if weight is None:
|
||||
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
|
||||
|
||||
@ -89,7 +89,8 @@ def get_additional_models(conds, dtype):
|
||||
gligen += get_models_from_cond(conds[k], "gligen")
|
||||
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||
|
||||
control_nets = set(cnets)
|
||||
# Order-preserving dedup. A plain set() would randomize iteration order across runs
|
||||
control_nets = list(dict.fromkeys(cnets))
|
||||
|
||||
inference_memory = 0
|
||||
control_models = []
|
||||
|
||||
39
comfy/sd.py
39
comfy/sd.py
@ -18,6 +18,7 @@ import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
import comfy.ldm.mmaudio.vae.autoencoder
|
||||
import comfy.pixel_space_convert
|
||||
@ -64,6 +65,8 @@ import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -478,7 +481,10 @@ class VAE:
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
|
||||
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
|
||||
elif "taesd_decoder.1.weight" in sd:
|
||||
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
||||
if isinstance(metadata, dict) and "tae_latent_channels" in metadata:
|
||||
self.latent_channels = metadata["tae_latent_channels"]
|
||||
else:
|
||||
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
|
||||
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
|
||||
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
|
||||
self.first_stage_model = StageA()
|
||||
@ -652,6 +658,17 @@ class VAE:
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2
|
||||
self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels)
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
@ -1208,6 +1225,7 @@ class CLIPType(Enum):
|
||||
NEWBIE = 24
|
||||
FLUX2 = 25
|
||||
LONGCAT_IMAGE = 26
|
||||
COGVIDEOX = 27
|
||||
|
||||
|
||||
|
||||
@ -1256,6 +1274,9 @@ class TEModel(Enum):
|
||||
QWEN35_9B = 26
|
||||
QWEN35_27B = 27
|
||||
MINISTRAL_3_3B = 28
|
||||
GEMMA_4_E4B = 29
|
||||
GEMMA_4_E2B = 30
|
||||
GEMMA_4_31B = 31
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1281,6 +1302,12 @@ def detect_te_model(sd):
|
||||
return TEModel.BYT5_SMALL_GLYPH
|
||||
return TEModel.T5_BASE
|
||||
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
|
||||
if 'model.layers.59.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_4_31B
|
||||
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
|
||||
return TEModel.GEMMA_4_E4B
|
||||
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
|
||||
return TEModel.GEMMA_4_E2B
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_12B
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
@ -1403,6 +1430,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif clip_type == CLIPType.COGVIDEOX:
|
||||
clip_target.clip = comfy.text_encoders.cogvideo.cogvideo_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.cogvideo.CogVideoXTokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||
@ -1420,6 +1450,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
|
||||
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
|
||||
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
|
||||
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
|
||||
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
|
||||
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
|
||||
clip_target.tokenizer = variant.tokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
|
||||
@ -27,6 +27,7 @@ import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.cogvideo
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -1166,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
|
||||
|
||||
class WAN21_CausalAR_T2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "t2v",
|
||||
"causal_ar": True,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 5.0,
|
||||
}
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.unet_config.pop("causal_ar", None)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.WAN21_CausalAR(self, device=device)
|
||||
|
||||
|
||||
class WAN21_I2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1832,6 +1852,156 @@ class SAM31(SAM3):
|
||||
unet_config = {"image_model": "SAM31"}
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31]
|
||||
class CogVideoX_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "cogvideox",
|
||||
}
|
||||
|
||||
models += [SVD_img2vid]
|
||||
sampling_settings = {
|
||||
"linear_start": 0.00085,
|
||||
"linear_end": 0.012,
|
||||
"beta_schedule": "linear",
|
||||
"zsnr": True,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.CogVideoX
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
# 2b-class (dim=1920, heads=30) uses scale_factor=1.15258426.
|
||||
# 5b-class (dim=3072, heads=48) — incl. CogVideoX-5b, 1.5-5B, and
|
||||
# Fun-V1.5 inpainting — uses scale_factor=0.7 per vae/config.json.
|
||||
if unet_config.get("num_attention_heads", 0) >= 48:
|
||||
self.latent_format = latent_formats.CogVideoX1_5
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
# CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
self.unet_config.setdefault("sample_height", 96)
|
||||
self.unet_config.setdefault("sample_width", 170)
|
||||
self.unet_config.setdefault("sample_frames", 81)
|
||||
out = model_base.CogVideoX(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.cogvideo.CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel)
|
||||
|
||||
class CogVideoX_I2V(CogVideoX_T2V):
|
||||
unet_config = {
|
||||
"image_model": "cogvideox",
|
||||
"in_channels": 32,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
self.unet_config.setdefault("sample_height", 96)
|
||||
self.unet_config.setdefault("sample_width", 170)
|
||||
self.unet_config.setdefault("sample_frames", 81)
|
||||
out = model_base.CogVideoX(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class CogVideoX_Inpaint(CogVideoX_T2V):
|
||||
unet_config = {
|
||||
"image_model": "cogvideox",
|
||||
"in_channels": 48,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
self.unet_config.setdefault("sample_height", 96)
|
||||
self.unet_config.setdefault("sample_width", 170)
|
||||
self.unet_config.setdefault("sample_frames", 81)
|
||||
out = model_base.CogVideoX(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
|
||||
models = [
|
||||
LotusD,
|
||||
Stable_Zero123,
|
||||
SD15_instructpix2pix,
|
||||
SD15,
|
||||
SD20,
|
||||
SD21UnclipL,
|
||||
SD21UnclipH,
|
||||
SDXL_instructpix2pix,
|
||||
SDXLRefiner,
|
||||
SDXL,
|
||||
SSD1B,
|
||||
KOALA_700M,
|
||||
KOALA_1B,
|
||||
Segmind_Vega,
|
||||
SD_X4Upscaler,
|
||||
Stable_Cascade_C,
|
||||
Stable_Cascade_B,
|
||||
SV3D_u,
|
||||
SV3D_p,
|
||||
SD3,
|
||||
StableAudio,
|
||||
AuraFlow,
|
||||
PixArtAlpha,
|
||||
PixArtSigma,
|
||||
HunyuanDiT,
|
||||
HunyuanDiT1,
|
||||
FluxInpaint,
|
||||
Flux,
|
||||
LongCatImage,
|
||||
FluxSchnell,
|
||||
GenmoMochi,
|
||||
LTXV,
|
||||
LTXAV,
|
||||
HunyuanVideo15_SR_Distilled,
|
||||
HunyuanVideo15,
|
||||
HunyuanImage21Refiner,
|
||||
HunyuanImage21,
|
||||
HunyuanVideoSkyreelsI2V,
|
||||
HunyuanVideoI2V,
|
||||
HunyuanVideo,
|
||||
CosmosT2V,
|
||||
CosmosI2V,
|
||||
CosmosT2IPredict2,
|
||||
CosmosI2VPredict2,
|
||||
ZImagePixelSpace,
|
||||
ZImage,
|
||||
Lumina2,
|
||||
WAN22_T2V,
|
||||
WAN21_CausalAR_T2V,
|
||||
WAN21_T2V,
|
||||
WAN21_I2V,
|
||||
WAN21_FunControl2V,
|
||||
WAN21_Vace,
|
||||
WAN21_Camera,
|
||||
WAN22_Camera,
|
||||
WAN22_S2V,
|
||||
WAN21_HuMo,
|
||||
WAN22_Animate,
|
||||
WAN21_FlowRVS,
|
||||
WAN21_SCAIL,
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
Hunyuan3Dv2_1,
|
||||
HiDream,
|
||||
Chroma,
|
||||
ChromaRadiance,
|
||||
ACEStep,
|
||||
ACEStep15,
|
||||
Omnigen2,
|
||||
QwenImage,
|
||||
Flux2,
|
||||
Kandinsky5Image,
|
||||
Kandinsky5,
|
||||
Anima,
|
||||
RT_DETR_v4,
|
||||
ErnieImage,
|
||||
SAM3,
|
||||
SAM31,
|
||||
CogVideoX_Inpaint,
|
||||
CogVideoX_I2V,
|
||||
CogVideoX_T2V,
|
||||
SVD_img2vid,
|
||||
]
|
||||
|
||||
@ -7,6 +7,7 @@ from tqdm.auto import tqdm
|
||||
from collections import namedtuple, deque
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
operations=comfy.ops.disable_weight_init
|
||||
|
||||
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
||||
@ -47,11 +48,14 @@ class TGrow(nn.Module):
|
||||
x = self.conv(x)
|
||||
return x.reshape(-1, C, H, W)
|
||||
|
||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, output_device=None,
|
||||
patch_size=1, decode=False):
|
||||
|
||||
B, T, C, H, W = x.shape
|
||||
if parallel:
|
||||
x = x.reshape(B*T, C, H, W)
|
||||
if not decode and patch_size > 1:
|
||||
x = F.pixel_unshuffle(x, patch_size)
|
||||
# parallel over input timesteps, iterate over blocks
|
||||
for b in tqdm(model, disable=not show_progress_bar):
|
||||
if isinstance(b, MemBlock):
|
||||
@ -62,20 +66,27 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
x = b(x, mem)
|
||||
else:
|
||||
x = b(x)
|
||||
BT, C, H, W = x.shape
|
||||
T = BT // B
|
||||
x = x.view(B, T, C, H, W)
|
||||
if decode and patch_size > 1:
|
||||
x = F.pixel_shuffle(x, patch_size)
|
||||
x = x.view(B, x.shape[0] // B, *x.shape[1:])
|
||||
x = x.to(output_device)
|
||||
else:
|
||||
out = []
|
||||
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
|
||||
# Chunk along the time dim directly (chunks are [B,1,C,H,W] views, squeeze to [B,C,H,W] views).
|
||||
# Avoids forcing a contiguous copy when x is non-contiguous (e.g. after movedim in encode/decode).
|
||||
work_queue = deque([TWorkItem(xt.squeeze(1), 0) for xt in x.chunk(T, dim=1)])
|
||||
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
||||
mem = [None] * len(model)
|
||||
while work_queue:
|
||||
xt, i = work_queue.popleft()
|
||||
if i == 0:
|
||||
progress_bar.update(1)
|
||||
if not decode and patch_size > 1:
|
||||
xt = F.pixel_unshuffle(xt, patch_size)
|
||||
if i == len(model):
|
||||
out.append(xt)
|
||||
if decode and patch_size > 1:
|
||||
xt = F.pixel_shuffle(xt, patch_size)
|
||||
out.append(xt.to(output_device))
|
||||
del xt
|
||||
else:
|
||||
b = model[i]
|
||||
@ -165,24 +176,20 @@ class TAEHV(nn.Module):
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if self.patch_size > 1:
|
||||
B, T, C, H, W = x.shape
|
||||
x = x.reshape(B * T, C, H, W)
|
||||
x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
|
||||
if x.shape[1] % self.t_downscale != 0:
|
||||
# pad at end to multiple of t_downscale
|
||||
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
|
||||
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
|
||||
x = torch.cat([x, padding], 1)
|
||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
|
||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar,
|
||||
patch_size=self.patch_size).movedim(2, 1)
|
||||
return self.process_out(x)
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
|
||||
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1:
|
||||
x = F.pixel_shuffle(x, self.patch_size)
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar,
|
||||
output_device=comfy.model_management.intermediate_device(),
|
||||
patch_size=self.patch_size, decode=True)
|
||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||
|
||||
@ -17,32 +17,79 @@ class Clamp(nn.Module):
|
||||
return torch.tanh(x / 3) * 3
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, n_in, n_out):
|
||||
def __init__(self, n_in: int, n_out: int, use_midblock_gn: bool = False):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
|
||||
self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.fuse = nn.ReLU()
|
||||
def forward(self, x):
|
||||
if not use_midblock_gn:
|
||||
self.pool = None
|
||||
return
|
||||
n_gn = n_in * 4
|
||||
self.pool = nn.Sequential(
|
||||
comfy.ops.disable_weight_init.Conv2d(n_in, n_gn, 1, bias=False),
|
||||
comfy.ops.disable_weight_init.GroupNorm(4, n_gn),
|
||||
nn.ReLU(inplace=True),
|
||||
comfy.ops.disable_weight_init.Conv2d(n_gn, n_in, 1, bias=False),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.pool is not None:
|
||||
x = x + self.pool(x)
|
||||
return self.fuse(self.conv(x) + self.skip(x))
|
||||
|
||||
def Encoder(latent_channels=4):
|
||||
return nn.Sequential(
|
||||
conv(3, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, latent_channels),
|
||||
)
|
||||
class Encoder(nn.Sequential):
|
||||
def __init__(self, latent_channels: int = 4, use_gn: bool = False):
|
||||
super().__init__(
|
||||
conv(3, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
|
||||
conv(64, 64, stride=2, bias=False), Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn),
|
||||
conv(64, latent_channels),
|
||||
)
|
||||
|
||||
class Decoder(nn.Sequential):
|
||||
def __init__(self, latent_channels: int = 4, use_gn: bool = False):
|
||||
super().__init__(
|
||||
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
||||
Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), conv(64, 3),
|
||||
)
|
||||
|
||||
class DecoderFlux2(Decoder):
|
||||
def __init__(self, latent_channels: int = 128, use_gn: bool = True):
|
||||
if latent_channels != 128 or not use_gn:
|
||||
raise ValueError("Unexpected parameters for Flux2 TAE module")
|
||||
super().__init__(latent_channels=32, use_gn=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, C, H, W = x.shape
|
||||
x = (
|
||||
x
|
||||
.reshape(B, 32, 2, 2, H, W)
|
||||
.permute(0, 1, 4, 2, 5, 3)
|
||||
.reshape(B, 32, H * 2, W * 2)
|
||||
)
|
||||
return super().forward(x)
|
||||
|
||||
class EncoderFlux2(Encoder):
|
||||
def __init__(self, latent_channels: int = 128, use_gn: bool = True):
|
||||
if latent_channels != 128 or not use_gn:
|
||||
raise ValueError("Unexpected parameters for Flux2 TAE module")
|
||||
super().__init__(latent_channels=32, use_gn=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
result = super().forward(x)
|
||||
B, C, H, W = result.shape
|
||||
return (
|
||||
result
|
||||
.reshape(B, C, H // 2, 2, W // 2, 2)
|
||||
.permute(0, 1, 3, 5, 2, 4)
|
||||
.reshape(B, 128, H // 2, W // 2)
|
||||
)
|
||||
|
||||
def Decoder(latent_channels=4):
|
||||
return nn.Sequential(
|
||||
Clamp(), conv(latent_channels, 64), nn.ReLU(),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
|
||||
Block(64, 64), conv(64, 3),
|
||||
)
|
||||
|
||||
class TAESD(nn.Module):
|
||||
latent_magnitude = 3
|
||||
@ -51,8 +98,15 @@ class TAESD(nn.Module):
|
||||
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
|
||||
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
|
||||
super().__init__()
|
||||
self.taesd_encoder = Encoder(latent_channels=latent_channels)
|
||||
self.taesd_decoder = Decoder(latent_channels=latent_channels)
|
||||
if latent_channels == 128:
|
||||
encoder_class = EncoderFlux2
|
||||
decoder_class = DecoderFlux2
|
||||
else:
|
||||
encoder_class = Encoder
|
||||
decoder_class = Decoder
|
||||
self.taesd_encoder = encoder_class(latent_channels=latent_channels)
|
||||
self.taesd_decoder = decoder_class(latent_channels=latent_channels)
|
||||
|
||||
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
|
||||
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
|
||||
if encoder_path is not None:
|
||||
@ -61,19 +115,19 @@ class TAESD(nn.Module):
|
||||
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
|
||||
|
||||
@staticmethod
|
||||
def scale_latents(x):
|
||||
def scale_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
"""raw latents -> [0, 1]"""
|
||||
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
|
||||
|
||||
@staticmethod
|
||||
def unscale_latents(x):
|
||||
def unscale_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
"""[0, 1] -> raw latents"""
|
||||
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
|
||||
|
||||
def decode(self, x):
|
||||
def decode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
|
||||
x_sample = x_sample.sub(0.5).mul(2)
|
||||
return x_sample
|
||||
|
||||
def encode(self, x):
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
|
||||
|
||||
48
comfy/text_encoders/cogvideo.py
Normal file
48
comfy/text_encoders/cogvideo.py
Normal file
@ -0,0 +1,48 @@
|
||||
import comfy.text_encoders.sd3_clip
|
||||
from comfy import sd1_clip
|
||||
|
||||
|
||||
class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
|
||||
"""Inner T5 tokenizer for CogVideoX.
|
||||
|
||||
CogVideoX was trained with T5 embeddings padded to 226 tokens (not 77 like SD3).
|
||||
Used both directly by supported_models.CogVideoX_T2V.clip_target (paired with
|
||||
the raw T5XXLModel) and by the CogVideoXTokenizer outer wrapper below.
|
||||
"""
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)
|
||||
|
||||
|
||||
class CogVideoXTokenizer(sd1_clip.SD1Tokenizer):
|
||||
"""Outer tokenizer wrapper for CLIPLoader (type="cogvideox")."""
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
clip_name="t5xxl", tokenizer=CogVideoXT5Tokenizer)
|
||||
|
||||
|
||||
class CogVideoXT5XXL(sd1_clip.SD1ClipModel):
|
||||
"""Outer T5XXL model wrapper for CLIPLoader (type="cogvideox").
|
||||
|
||||
Wraps the raw T5XXL model in the SD1ClipModel interface so that CLIP.__init__
|
||||
(which reads self.dtypes) works correctly. The inner model is the standard
|
||||
sd3_clip.T5XXLModel (no attention_mask change needed for CogVideoX).
|
||||
"""
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl",
|
||||
clip_model=comfy.text_encoders.sd3_clip.T5XXLModel,
|
||||
model_options=model_options)
|
||||
|
||||
|
||||
def cogvideo_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
"""Factory that returns a CogVideoXT5XXL class configured with the detected
|
||||
T5 dtype and optional quantization metadata, for use in load_text_encoder_state_dicts.
|
||||
"""
|
||||
class CogVideoXTEModel_(CogVideoXT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return CogVideoXTEModel_
|
||||
1298
comfy/text_encoders/gemma4.py
Normal file
1298
comfy/text_encoders/gemma4.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -521,7 +521,7 @@ class Attention(nn.Module):
|
||||
else:
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window:
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
|
||||
xk = xk[:, :, -sliding_window:]
|
||||
xv = xv[:, :, -sliding_window:]
|
||||
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
|
||||
@ -533,12 +533,12 @@ class Attention(nn.Module):
|
||||
return self.o_proj(output), present_key_value
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
|
||||
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
|
||||
super().__init__()
|
||||
ops = ops or nn
|
||||
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
intermediate_size = intermediate_size or config.intermediate_size
|
||||
self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
|
||||
self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
|
||||
if config.mlp_activation == "silu":
|
||||
self.activation = torch.nn.functional.silu
|
||||
elif config.mlp_activation == "gelu_pytorch_tanh":
|
||||
@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
|
||||
|
||||
return x, present_key_value
|
||||
|
||||
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
|
||||
class ScaledEmbedding(ops.Embedding):
|
||||
def forward(self, input_ids, out_dtype=None):
|
||||
return super().forward(input_ids, out_dtype=out_dtype) * scale
|
||||
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
|
||||
|
||||
|
||||
class Llama2_(nn.Module):
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = ops.Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
device=device,
|
||||
dtype=dtype
|
||||
)
|
||||
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
|
||||
transformer = TransformerBlockGemma2
|
||||
self.normalize_in = True
|
||||
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
|
||||
else:
|
||||
transformer = TransformerBlock
|
||||
self.normalize_in = False
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
@ -690,15 +691,12 @@ class Llama2_(nn.Module):
|
||||
self.config.rope_dims,
|
||||
device=device)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
x = self.embed_tokens(x, out_dtype=dtype)
|
||||
|
||||
if self.normalize_in:
|
||||
x *= self.config.hidden_size ** 0.5
|
||||
|
||||
seq_len = x.shape[1]
|
||||
past_len = 0
|
||||
if past_key_values is not None and len(past_key_values) > 0:
|
||||
@ -850,7 +848,7 @@ class BaseGenerate:
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
@ -875,14 +873,16 @@ class BaseGenerate:
|
||||
pbar = comfy.utils.ProgressBar(max_length)
|
||||
|
||||
# Generation loop
|
||||
current_input_ids = initial_input_ids
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
token_id = next_token[0].item()
|
||||
generated_token_ids.append(token_id)
|
||||
|
||||
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||
current_input_ids = next_token if initial_input_ids is not None else None
|
||||
pbar.update(1)
|
||||
|
||||
if token_id in stop_tokens:
|
||||
|
||||
@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
|
||||
|
||||
class DualLinearProjection(torch.nn.Module):
|
||||
|
||||
@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
embeds, _, _, _ = super().process_tokens(tokens, device)
|
||||
return embeds
|
||||
|
||||
class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
|
||||
@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
|
||||
nn.Module.__init__(self)
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
self.normalize_in = False
|
||||
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
|
||||
"""Normalize image embeddings to match text embedding scale"""
|
||||
for info in embeds_info:
|
||||
if info.get("type") == "image":
|
||||
start_idx = info["index"]
|
||||
end_idx = start_idx + info["size"]
|
||||
embeds[:, start_idx:end_idx, :] /= scale_factor
|
||||
|
||||
@ -5,12 +5,95 @@ This module handles capability negotiation between frontend and backend,
|
||||
allowing graceful protocol evolution while maintaining backward compatibility.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
import logging
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
|
||||
class FeatureFlagInfo(TypedDict):
|
||||
type: str
|
||||
default: Any
|
||||
description: str
|
||||
|
||||
|
||||
# Registry of known CLI-settable feature flags.
|
||||
# Launchers can query this via --list-feature-flags to discover valid flags.
|
||||
CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = {
|
||||
"show_signin_button": {
|
||||
"type": "bool",
|
||||
"default": False,
|
||||
"description": "Show the sign-in button in the frontend even when not signed in",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _coerce_bool(v: str) -> bool:
|
||||
"""Strict bool coercion: only 'true'/'false' (case-insensitive).
|
||||
|
||||
Anything else raises ValueError so the caller can warn and drop the flag,
|
||||
rather than silently treating typos like 'ture' or 'yes' as False.
|
||||
"""
|
||||
lower = v.lower()
|
||||
if lower == "true":
|
||||
return True
|
||||
if lower == "false":
|
||||
return False
|
||||
raise ValueError(f"expected 'true' or 'false', got {v!r}")
|
||||
|
||||
|
||||
_COERCE_FNS: dict[str, Any] = {
|
||||
"bool": _coerce_bool,
|
||||
"int": lambda v: int(v),
|
||||
"float": lambda v: float(v),
|
||||
}
|
||||
|
||||
|
||||
def _coerce_flag_value(key: str, raw_value: str) -> Any:
|
||||
"""Coerce a raw string value using the registry type, or keep as string.
|
||||
|
||||
Returns the raw string if the key is unregistered or the type is unknown.
|
||||
Raises ValueError/TypeError if the key is registered with a known type but
|
||||
the value cannot be coerced; callers are expected to warn and drop the flag.
|
||||
"""
|
||||
info = CLI_FEATURE_FLAG_REGISTRY.get(key)
|
||||
if info is None:
|
||||
return raw_value
|
||||
coerce = _COERCE_FNS.get(info["type"])
|
||||
if coerce is None:
|
||||
return raw_value
|
||||
return coerce(raw_value)
|
||||
|
||||
|
||||
def _parse_cli_feature_flags() -> dict[str, Any]:
|
||||
"""Parse --feature-flag key=value pairs from CLI args into a dict.
|
||||
|
||||
Items without '=' default to the value 'true' (bare flag form).
|
||||
Flags whose value cannot be coerced to the registered type are dropped
|
||||
with a warning, so a typo like '--feature-flag some_bool=ture' does not
|
||||
silently take effect as the wrong value.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
for item in getattr(args, "feature_flag", []):
|
||||
key, sep, raw_value = item.partition("=")
|
||||
key = key.strip()
|
||||
if not key:
|
||||
continue
|
||||
if not sep:
|
||||
raw_value = "true"
|
||||
try:
|
||||
result[key] = _coerce_flag_value(key, raw_value.strip())
|
||||
except (ValueError, TypeError) as e:
|
||||
info = CLI_FEATURE_FLAG_REGISTRY.get(key, {})
|
||||
logging.warning(
|
||||
"Could not coerce --feature-flag %s=%r to %s (%s); dropping flag.",
|
||||
key, raw_value.strip(), info.get("type", "?"), e,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# Default server capabilities
|
||||
SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
||||
_CORE_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
@ -18,6 +101,11 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"assets": args.enable_assets,
|
||||
}
|
||||
|
||||
# CLI-provided flags cannot overwrite core flags
|
||||
_cli_flags = {k: v for k, v in _parse_cli_feature_flags().items() if k not in _CORE_FEATURE_FLAGS}
|
||||
|
||||
SERVER_FEATURE_FLAGS: dict[str, Any] = {**_CORE_FEATURE_FLAGS, **_cli_flags}
|
||||
|
||||
|
||||
def get_connection_feature(
|
||||
sockets_metadata: dict[str, dict[str, Any]],
|
||||
|
||||
@ -251,6 +251,7 @@ class VideoFromFile(VideoInput):
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
|
||||
image_format = 'gbrpf32le'
|
||||
process_image_format = lambda a: a
|
||||
audio = None
|
||||
|
||||
streams = [video_stream]
|
||||
@ -283,11 +284,25 @@ class VideoFromFile(VideoInput):
|
||||
break
|
||||
|
||||
if not checked_alpha:
|
||||
alpha_channel = False
|
||||
for comp in frame.format.components:
|
||||
if comp.is_alpha or frame.format.name == "pal8":
|
||||
alphas = []
|
||||
image_format = 'gbrapf32le'
|
||||
alpha_channel = True
|
||||
break
|
||||
if frame.format.name in ("yuvj420p", "yuvj422p", "yuvj444p", "rgb24", "rgba", "pal8"):
|
||||
process_image_format = lambda a: a.float() / 255.0
|
||||
if alpha_channel:
|
||||
image_format = 'rgba'
|
||||
else:
|
||||
image_format = 'rgb24'
|
||||
else:
|
||||
process_image_format = lambda a: a
|
||||
if alpha_channel:
|
||||
image_format = 'gbrapf32le'
|
||||
else:
|
||||
image_format = 'gbrpf32le'
|
||||
|
||||
checked_alpha = True
|
||||
|
||||
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
|
||||
@ -323,9 +338,9 @@ class VideoFromFile(VideoInput):
|
||||
else:
|
||||
audio_frames.append(frame.to_ndarray())
|
||||
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
|
||||
images = process_image_format(torch.stack(frames)) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
|
||||
if alphas is not None:
|
||||
alphas = torch.stack(alphas) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1)
|
||||
alphas = process_image_format(torch.stack(alphas)) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1)
|
||||
|
||||
# Get frame rate
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
|
||||
|
||||
@ -395,7 +395,6 @@ class Combo(ComfyTypeIO):
|
||||
@comfytype(io_type="COMBO")
|
||||
class MultiCombo(ComfyTypeI):
|
||||
'''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
|
||||
# TODO: something is wrong with the serialization, frontend does not recognize it as multiselect
|
||||
Type = list[str]
|
||||
class Input(Combo.Input):
|
||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
@ -408,12 +407,14 @@ class MultiCombo(ComfyTypeI):
|
||||
self.default: list[str]
|
||||
|
||||
def as_dict(self):
|
||||
to_return = super().as_dict() | prune_dict({
|
||||
"multi_select": self.multiselect,
|
||||
"placeholder": self.placeholder,
|
||||
"chip": self.chip,
|
||||
# Frontend expects `multi_select` to be an object config (not a boolean).
|
||||
# Keep top-level `multiselect` from Combo.Input for backwards compatibility.
|
||||
return super().as_dict() | prune_dict({
|
||||
"multi_select": prune_dict({
|
||||
"placeholder": self.placeholder,
|
||||
"chip": self.chip,
|
||||
}),
|
||||
})
|
||||
return to_return
|
||||
|
||||
@comfytype(io_type="IMAGE")
|
||||
class Image(ComfyTypeIO):
|
||||
|
||||
@ -157,6 +157,11 @@ class SeedanceCreateAssetResponse(BaseModel):
|
||||
asset_id: str = Field(...)
|
||||
|
||||
|
||||
class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
|
||||
url: str = Field(..., description="Publicly accessible URL of the image asset to upload.")
|
||||
hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.")
|
||||
|
||||
|
||||
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||
SEEDANCE2_PRICE_PER_1K_TOKENS = {
|
||||
("dreamina-seedance-2-0-260128", False): 0.007,
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field, confloat
|
||||
|
||||
|
||||
|
||||
class LumaIO:
|
||||
LUMA_REF = "LUMA_REF"
|
||||
LUMA_CONCEPTS = "LUMA_CONCEPTS"
|
||||
@ -183,13 +180,13 @@ class LumaAssets(BaseModel):
|
||||
|
||||
|
||||
class LumaImageRef(BaseModel):
|
||||
'''Used for image gen'''
|
||||
"""Used for image gen"""
|
||||
url: str = Field(..., description='The URL of the image reference')
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||
|
||||
|
||||
class LumaImageReference(BaseModel):
|
||||
'''Used for video gen'''
|
||||
"""Used for video gen"""
|
||||
type: Optional[str] = Field('image', description='Input type, defaults to image')
|
||||
url: str = Field(..., description='The URL of the image')
|
||||
|
||||
@ -251,3 +248,32 @@ class LumaGeneration(BaseModel):
|
||||
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
|
||||
model: str = Field(..., description='The model used for the generation')
|
||||
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
|
||||
|
||||
|
||||
class Luma2ImageRef(BaseModel):
|
||||
url: str | None = None
|
||||
data: str | None = None
|
||||
media_type: str | None = None
|
||||
|
||||
|
||||
class Luma2GenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., min_length=1, max_length=6000)
|
||||
model: str | None = None
|
||||
type: str | None = None
|
||||
aspect_ratio: str | None = None
|
||||
style: str | None = None
|
||||
output_format: str | None = None
|
||||
web_search: bool | None = None
|
||||
image_ref: list[Luma2ImageRef] | None = None
|
||||
source: Luma2ImageRef | None = None
|
||||
|
||||
|
||||
class Luma2Generation(BaseModel):
|
||||
id: str | None = None
|
||||
type: str | None = None
|
||||
state: str | None = None
|
||||
model: str | None = None
|
||||
created_at: str | None = None
|
||||
output: list[LumaImageReference] | None = None
|
||||
failure_reason: str | None = None
|
||||
failure_code: str | None = None
|
||||
|
||||
@ -1,152 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from pydantic import BaseModel, Field, StrictBytes
|
||||
|
||||
|
||||
class MoonvalleyPromptResponse(BaseModel):
|
||||
error: Optional[Dict[str, Any]] = None
|
||||
frame_conditioning: Optional[Dict[str, Any]] = None
|
||||
id: Optional[str] = None
|
||||
inference_params: Optional[Dict[str, Any]] = None
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
model_params: Optional[Dict[str, Any]] = None
|
||||
output_url: Optional[str] = None
|
||||
prompt_text: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyTextToVideoInferenceParams(BaseModel):
|
||||
add_quality_guidance: Optional[bool] = Field(
|
||||
True, description='Whether to add quality guidance'
|
||||
)
|
||||
caching_coefficient: Optional[float] = Field(
|
||||
0.3, description='Caching coefficient for optimization'
|
||||
)
|
||||
caching_cooldown: Optional[int] = Field(
|
||||
3, description='Number of caching cooldown steps'
|
||||
)
|
||||
caching_warmup: Optional[int] = Field(
|
||||
3, description='Number of caching warmup steps'
|
||||
)
|
||||
clip_value: Optional[float] = Field(
|
||||
3, description='CLIP value for generation control'
|
||||
)
|
||||
conditioning_frame_index: Optional[int] = Field(
|
||||
0, description='Index of the conditioning frame'
|
||||
)
|
||||
cooldown_steps: Optional[int] = Field(
|
||||
75, description='Number of cooldown steps (calculated based on num_frames)'
|
||||
)
|
||||
fps: Optional[int] = Field(
|
||||
24, description='Frames per second of the generated video'
|
||||
)
|
||||
guidance_scale: Optional[float] = Field(
|
||||
10, description='Guidance scale for generation control'
|
||||
)
|
||||
height: Optional[int] = Field(
|
||||
1080, description='Height of the generated video in pixels'
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
|
||||
num_frames: Optional[int] = Field(64, description='Number of frames to generate')
|
||||
seed: Optional[int] = Field(
|
||||
None, description='Random seed for generation (default: random)'
|
||||
)
|
||||
shift_value: Optional[float] = Field(
|
||||
3, description='Shift value for generation control'
|
||||
)
|
||||
steps: Optional[int] = Field(80, description='Number of denoising steps')
|
||||
use_guidance_schedule: Optional[bool] = Field(
|
||||
True, description='Whether to use guidance scheduling'
|
||||
)
|
||||
use_negative_prompts: Optional[bool] = Field(
|
||||
False, description='Whether to use negative prompts'
|
||||
)
|
||||
use_timestep_transform: Optional[bool] = Field(
|
||||
True, description='Whether to use timestep transformation'
|
||||
)
|
||||
warmup_steps: Optional[int] = Field(
|
||||
0, description='Number of warmup steps (calculated based on num_frames)'
|
||||
)
|
||||
width: Optional[int] = Field(
|
||||
1920, description='Width of the generated video in pixels'
|
||||
)
|
||||
|
||||
|
||||
class MoonvalleyTextToVideoRequest(BaseModel):
|
||||
image_url: Optional[str] = None
|
||||
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
|
||||
prompt_text: Optional[str] = None
|
||||
webhook_url: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyUploadFileRequest(BaseModel):
|
||||
file: Optional[StrictBytes] = None
|
||||
|
||||
|
||||
class MoonvalleyUploadFileResponse(BaseModel):
|
||||
access_url: Optional[str] = None
|
||||
|
||||
|
||||
class MoonvalleyVideoToVideoInferenceParams(BaseModel):
|
||||
add_quality_guidance: Optional[bool] = Field(
|
||||
True, description='Whether to add quality guidance'
|
||||
)
|
||||
caching_coefficient: Optional[float] = Field(
|
||||
0.3, description='Caching coefficient for optimization'
|
||||
)
|
||||
caching_cooldown: Optional[int] = Field(
|
||||
3, description='Number of caching cooldown steps'
|
||||
)
|
||||
caching_warmup: Optional[int] = Field(
|
||||
3, description='Number of caching warmup steps'
|
||||
)
|
||||
clip_value: Optional[float] = Field(
|
||||
3, description='CLIP value for generation control'
|
||||
)
|
||||
conditioning_frame_index: Optional[int] = Field(
|
||||
0, description='Index of the conditioning frame'
|
||||
)
|
||||
cooldown_steps: Optional[int] = Field(
|
||||
36, description='Number of cooldown steps (calculated based on num_frames)'
|
||||
)
|
||||
guidance_scale: Optional[float] = Field(
|
||||
15, description='Guidance scale for generation control'
|
||||
)
|
||||
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
|
||||
seed: Optional[int] = Field(
|
||||
None, description='Random seed for generation (default: random)'
|
||||
)
|
||||
shift_value: Optional[float] = Field(
|
||||
3, description='Shift value for generation control'
|
||||
)
|
||||
steps: Optional[int] = Field(80, description='Number of denoising steps')
|
||||
use_guidance_schedule: Optional[bool] = Field(
|
||||
True, description='Whether to use guidance scheduling'
|
||||
)
|
||||
use_negative_prompts: Optional[bool] = Field(
|
||||
False, description='Whether to use negative prompts'
|
||||
)
|
||||
use_timestep_transform: Optional[bool] = Field(
|
||||
True, description='Whether to use timestep transformation'
|
||||
)
|
||||
warmup_steps: Optional[int] = Field(
|
||||
24, description='Number of warmup steps (calculated based on num_frames)'
|
||||
)
|
||||
|
||||
|
||||
class ControlType(str, Enum):
|
||||
motion_control = 'motion_control'
|
||||
pose_control = 'pose_control'
|
||||
|
||||
|
||||
class MoonvalleyVideoToVideoRequest(BaseModel):
|
||||
control_type: ControlType = Field(
|
||||
..., description='Supported types for video control'
|
||||
)
|
||||
inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None
|
||||
prompt_text: str = Field(..., description='Describes the video to generate')
|
||||
video_url: str = Field(..., description='Url to control video')
|
||||
webhook_url: Optional[str] = Field(
|
||||
None, description='Optional webhook URL for notifications'
|
||||
)
|
||||
@ -56,14 +56,14 @@ class ModelResponseProperties(BaseModel):
|
||||
instructions: str | None = Field(None)
|
||||
max_output_tokens: int | None = Field(None)
|
||||
model: str | None = Field(None)
|
||||
temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0)
|
||||
temperature: float | None = Field(None, description="Controls randomness in the response", ge=0.0, le=2.0)
|
||||
top_p: float | None = Field(
|
||||
1,
|
||||
None,
|
||||
description="Controls diversity of the response via nucleus sampling",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'")
|
||||
truncation: str | None = Field(None, description="Allowed values: 'auto' or 'disabled'")
|
||||
|
||||
|
||||
class ResponseProperties(BaseModel):
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
|
||||
grain: Optional[float] = Field(None, description="Grain after AI model processing")
|
||||
grainSize: Optional[float] = Field(None, description="Size of generated grain")
|
||||
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
|
||||
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
|
||||
creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
|
||||
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
|
||||
prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
|
||||
sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
|
||||
realism: float | None = Field(None, description="ast-2 realism control")
|
||||
|
||||
|
||||
class OutputInformationVideo(BaseModel):
|
||||
@ -90,7 +93,7 @@ class Overrides(BaseModel):
|
||||
|
||||
class CreateVideoRequest(BaseModel):
|
||||
source: CreateVideoRequestSource = Field(...)
|
||||
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
|
||||
filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
|
||||
output: OutputInformationVideo = Field(...)
|
||||
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
@ -20,6 +21,7 @@ from comfy_api_nodes.apis.bytedance import (
|
||||
SeedanceCreateAssetResponse,
|
||||
SeedanceCreateVisualValidateSessionResponse,
|
||||
SeedanceGetVisualValidateSessionResponse,
|
||||
SeedanceVirtualLibraryCreateAssetRequest,
|
||||
Seedream4Options,
|
||||
Seedream4TaskCreationRequest,
|
||||
TaskAudioContent,
|
||||
@ -271,6 +273,30 @@ async def _wait_for_asset_active(cls: type[IO.ComfyNode], asset_id: str, group_i
|
||||
)
|
||||
|
||||
|
||||
async def _seedance_virtual_library_upload_image_asset(
|
||||
cls: type[IO.ComfyNode],
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
wait_label: str = "Uploading image",
|
||||
) -> str:
|
||||
"""Upload an image into the caller's per-customer Seedance virtual library."""
|
||||
public_url = await upload_image_to_comfyapi(cls, image, wait_label=wait_label)
|
||||
normalized = image.detach().cpu().contiguous().to(torch.float32)
|
||||
digest = hashlib.sha256()
|
||||
digest.update(str(tuple(normalized.shape)).encode("utf-8"))
|
||||
digest.update(b"\0")
|
||||
digest.update(normalized.numpy().tobytes())
|
||||
image_hash = digest.hexdigest()
|
||||
create_resp = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"),
|
||||
response_model=SeedanceCreateAssetResponse,
|
||||
data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=image_hash),
|
||||
)
|
||||
await _wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library")
|
||||
return f"asset://{create_resp.asset_id}"
|
||||
|
||||
|
||||
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||
@ -1377,7 +1403,6 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||
poll_interval=9,
|
||||
max_poll_attempts=180,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
@ -1507,7 +1532,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
if first_frame_asset_id:
|
||||
first_frame_url = image_assets[first_frame_asset_id]
|
||||
else:
|
||||
first_frame_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.")
|
||||
first_frame_url = await _seedance_virtual_library_upload_image_asset(
|
||||
cls, first_frame, wait_label="Uploading first frame."
|
||||
)
|
||||
|
||||
content: list[TaskTextContent | TaskImageContent] = [
|
||||
TaskTextContent(text=model["prompt"]),
|
||||
@ -1527,7 +1554,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
url=await upload_image_to_comfyapi(cls, last_frame, wait_label="Uploading last frame.")
|
||||
url=await _seedance_virtual_library_upload_image_asset(
|
||||
cls, last_frame, wait_label="Uploading last frame."
|
||||
)
|
||||
),
|
||||
role="last_frame",
|
||||
),
|
||||
@ -1555,7 +1584,6 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
|
||||
poll_interval=9,
|
||||
max_poll_attempts=180,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
@ -1805,9 +1833,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
content.append(
|
||||
TaskImageContent(
|
||||
image_url=TaskImageContentUrl(
|
||||
url=await upload_image_to_comfyapi(
|
||||
url=await _seedance_virtual_library_upload_image_asset(
|
||||
cls,
|
||||
image=reference_images[key],
|
||||
reference_images[key],
|
||||
wait_label=f"Uploading image {i}",
|
||||
),
|
||||
),
|
||||
@ -1877,7 +1905,6 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
status_extractor=lambda r: r.status,
|
||||
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
|
||||
poll_interval=9,
|
||||
max_poll_attempts=180,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
@ -178,7 +178,6 @@ class HitPawGeneralImageEnhance(IO.ComfyNode):
|
||||
status_extractor=lambda x: x.data.status,
|
||||
price_extractor=lambda x: request_price,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url))
|
||||
|
||||
@ -324,7 +323,6 @@ class HitPawVideoEnhance(IO.ComfyNode):
|
||||
status_extractor=lambda x: x.data.status,
|
||||
price_extractor=lambda x: request_price,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=320,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url))
|
||||
|
||||
|
||||
@ -276,7 +276,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
max_poll_attempts=280,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
@ -3062,7 +3061,6 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
cls,
|
||||
ApiEndpoint(path=poll_path),
|
||||
response_model=TaskStatusResponse,
|
||||
max_poll_attempts=280,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
@ -3188,7 +3186,6 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
max_poll_attempts=280,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.luma import (
|
||||
Luma2Generation,
|
||||
Luma2GenerationRequest,
|
||||
Luma2ImageRef,
|
||||
LumaAspectRatio,
|
||||
LumaCharacterRef,
|
||||
LumaConceptChain,
|
||||
@ -30,6 +31,7 @@ from comfy_api_nodes.util import (
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
@ -212,9 +214,9 @@ class LumaImageGenerationNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
seed,
|
||||
style_image_weight: float,
|
||||
image_luma_ref: Optional[LumaReferenceChain] = None,
|
||||
style_image: Optional[torch.Tensor] = None,
|
||||
character_image: Optional[torch.Tensor] = None,
|
||||
image_luma_ref: LumaReferenceChain | None = None,
|
||||
style_image: torch.Tensor | None = None,
|
||||
character_image: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=3)
|
||||
# handle image_luma_ref
|
||||
@ -434,7 +436,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
duration: str,
|
||||
loop: bool,
|
||||
seed,
|
||||
luma_concepts: Optional[LumaConceptChain] = None,
|
||||
luma_concepts: LumaConceptChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=3)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
@ -533,7 +535,6 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=PRICE_BADGE_VIDEO,
|
||||
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -644,6 +645,293 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
||||
)
|
||||
|
||||
|
||||
def _luma2_uni1_common_inputs(max_image_refs: int) -> list:
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"style",
|
||||
options=["auto", "manga"],
|
||||
default="auto",
|
||||
tooltip="Style preset. 'auto' picks based on the prompt; "
|
||||
"'manga' applies a manga/anime aesthetic and requires a portrait "
|
||||
"aspect ratio (2:3, 9:16, 1:2, 1:3).",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"web_search",
|
||||
default=False,
|
||||
tooltip="Search the web for visual references before generating.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"image_ref",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, max_image_refs + 1)],
|
||||
min=0,
|
||||
),
|
||||
optional=True,
|
||||
tooltip=f"Up to {max_image_refs} reference images for style/content guidance.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def _luma2_upload_image_refs(
|
||||
cls: type[IO.ComfyNode],
|
||||
refs: dict | None,
|
||||
max_count: int,
|
||||
) -> list[Luma2ImageRef] | None:
|
||||
if not refs:
|
||||
return None
|
||||
out: list[Luma2ImageRef] = []
|
||||
for key in refs:
|
||||
url = await upload_image_to_comfyapi(cls, refs[key])
|
||||
out.append(Luma2ImageRef(url=url))
|
||||
if len(out) > max_count:
|
||||
raise ValueError(f"Maximum {max_count} reference images are allowed.")
|
||||
return out or None
|
||||
|
||||
|
||||
async def _luma2_submit_and_poll(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: Luma2GenerationRequest,
|
||||
) -> Input.Image:
|
||||
initial = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma_2/generations", method="POST"),
|
||||
response_model=Luma2Generation,
|
||||
data=request,
|
||||
)
|
||||
if not initial.id:
|
||||
raise RuntimeError("Luma 2 API did not return a generation id.")
|
||||
final = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"),
|
||||
response_model=Luma2Generation,
|
||||
status_extractor=lambda r: r.state,
|
||||
progress_extractor=lambda r: None,
|
||||
)
|
||||
if not final.output:
|
||||
msg = final.failure_reason or "no output returned"
|
||||
raise RuntimeError(f"Luma 2 generation failed: {msg}")
|
||||
url = final.output[0].url
|
||||
if not url:
|
||||
raise RuntimeError("Luma 2 generation completed without an output URL.")
|
||||
return await download_url_to_image_tensor(url)
|
||||
|
||||
|
||||
class LumaImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode2",
|
||||
display_name="Luma UNI-1 Image",
|
||||
category="api node/image/Luma",
|
||||
description="Generate images from text using the Luma UNI-1 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired image. 1–6000 characters.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"uni-1",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[
|
||||
"auto",
|
||||
"3:1",
|
||||
"2:1",
|
||||
"16:9",
|
||||
"3:2",
|
||||
"1:1",
|
||||
"2:3",
|
||||
"9:16",
|
||||
"1:2",
|
||||
"1:3",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="Output image aspect ratio. 'auto' lets "
|
||||
"the model pick based on the prompt.",
|
||||
),
|
||||
*_luma2_uni1_common_inputs(max_image_refs=9),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"uni-1-max",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[
|
||||
"auto",
|
||||
"3:1",
|
||||
"2:1",
|
||||
"16:9",
|
||||
"3:2",
|
||||
"1:1",
|
||||
"2:3",
|
||||
"9:16",
|
||||
"1:2",
|
||||
"1:3",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="Output image aspect ratio. 'auto' lets "
|
||||
"the model pick based on the prompt.",
|
||||
),
|
||||
*_luma2_uni1_common_inputs(max_image_refs=9),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$refs := $lookup(inputGroups, "model.image_ref");
|
||||
$base := $m = "uni-1-max" ? 0.1 : 0.0404;
|
||||
{"type":"usd","usd": $round($base + 0.003 * $refs, 4)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=6000)
|
||||
aspect_ratio = model["aspect_ratio"]
|
||||
style = model["style"]
|
||||
allowed_manga_ratios = {"2:3", "9:16", "1:2", "1:3"}
|
||||
if style == "manga" and aspect_ratio != "auto" and aspect_ratio not in allowed_manga_ratios:
|
||||
raise ValueError(
|
||||
f"'manga' style requires a portrait aspect ratio "
|
||||
f"({', '.join(sorted(allowed_manga_ratios))}) or 'auto'; got '{aspect_ratio}'."
|
||||
)
|
||||
request = Luma2GenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model["model"],
|
||||
type="image",
|
||||
aspect_ratio=aspect_ratio if aspect_ratio != "auto" else None,
|
||||
style=style if style != "auto" else None,
|
||||
output_format="png",
|
||||
web_search=model["web_search"],
|
||||
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9),
|
||||
)
|
||||
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
|
||||
|
||||
|
||||
class LumaImageEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageEditNode2",
|
||||
display_name="Luma UNI-1 Image Edit",
|
||||
category="api node/image/Luma",
|
||||
description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"source",
|
||||
tooltip="Source image to edit.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Description of the desired edit. 1–6000 characters.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"uni-1",
|
||||
_luma2_uni1_common_inputs(max_image_refs=8),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"uni-1-max",
|
||||
_luma2_uni1_common_inputs(max_image_refs=8),
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for editing.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$refs := $lookup(inputGroups, "model.image_ref");
|
||||
$base := $m = "uni-1-max" ? 0.103 : 0.0434;
|
||||
{"type":"usd","usd": $round($base + 0.003 * $refs, 4)}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
source: Input.Image,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=6000)
|
||||
request = Luma2GenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model["model"],
|
||||
type="image_edit",
|
||||
source=Luma2ImageRef(url=await upload_image_to_comfyapi(cls, source)),
|
||||
style=model["style"] if model["style"] != "auto" else None,
|
||||
output_format="png",
|
||||
web_search=model["web_search"],
|
||||
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8),
|
||||
)
|
||||
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
|
||||
|
||||
|
||||
class LumaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -654,6 +942,8 @@ class LumaExtension(ComfyExtension):
|
||||
LumaImageToVideoGenerationNode,
|
||||
LumaReferenceNode,
|
||||
LumaConceptsNode,
|
||||
LumaImageNode,
|
||||
LumaImageEditNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -230,7 +230,6 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
status_extractor=lambda x: x.status,
|
||||
price_extractor=lambda _: price_usd,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
|
||||
|
||||
@ -391,7 +390,6 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
status_extractor=lambda x: x.status,
|
||||
price_extractor=lambda _: price_usd,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
|
||||
|
||||
@ -541,7 +539,6 @@ class MagnificImageStyleTransferNode(IO.ComfyNode):
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
|
||||
|
||||
@ -782,7 +779,6 @@ class MagnificImageRelightNode(IO.ComfyNode):
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
|
||||
|
||||
@ -924,7 +920,6 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode):
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
|
||||
|
||||
|
||||
@ -1,534 +0,0 @@
|
||||
import logging
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.moonvalley import (
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
trim_video,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_container_format_is_mp4,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
|
||||
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
|
||||
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
|
||||
API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video"
|
||||
API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video"
|
||||
|
||||
MIN_WIDTH = 300
|
||||
MIN_HEIGHT = 300
|
||||
|
||||
MAX_WIDTH = 10000
|
||||
MAX_HEIGHT = 10000
|
||||
|
||||
MIN_VID_WIDTH = 300
|
||||
MIN_VID_HEIGHT = 300
|
||||
|
||||
MAX_VID_WIDTH = 10000
|
||||
MAX_VID_HEIGHT = 10000
|
||||
|
||||
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
|
||||
|
||||
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
|
||||
|
||||
|
||||
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
|
||||
"""Verifies that the initial response contains a task ID."""
|
||||
return bool(response.id)
|
||||
|
||||
|
||||
def validate_task_creation_response(response) -> None:
|
||||
if not is_valid_task_creation_response(response):
|
||||
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
|
||||
"""
|
||||
Validates and processes video input for Moonvalley Video-to-Video generation.
|
||||
|
||||
Args:
|
||||
video: Input video to validate
|
||||
|
||||
Returns:
|
||||
Validated and potentially trimmed video
|
||||
|
||||
Raises:
|
||||
ValueError: If video doesn't meet requirements
|
||||
MoonvalleyApiError: If video duration is too short
|
||||
"""
|
||||
width, height = _get_video_dimensions(video)
|
||||
_validate_video_dimensions(width, height)
|
||||
validate_container_format_is_mp4(video)
|
||||
|
||||
return _validate_and_trim_duration(video)
|
||||
|
||||
|
||||
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
|
||||
"""Extracts video dimensions with error handling."""
|
||||
try:
|
||||
return video.get_dimensions()
|
||||
except Exception as e:
|
||||
logging.error("Error getting dimensions of video: %s", e)
|
||||
raise ValueError(f"Cannot get video dimensions: {e}") from e
|
||||
|
||||
|
||||
def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
"""Validates video dimensions meet Moonvalley V2V requirements."""
|
||||
supported_resolutions = {
|
||||
(1920, 1080),
|
||||
(1080, 1920),
|
||||
(1152, 1152),
|
||||
(1536, 1152),
|
||||
(1152, 1536),
|
||||
}
|
||||
|
||||
if (width, height) not in supported_resolutions:
|
||||
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
|
||||
"""Validates video duration and trims to 5 seconds if needed."""
|
||||
duration = video.get_duration()
|
||||
_validate_minimum_duration(duration)
|
||||
return _trim_if_too_long(video, duration)
|
||||
|
||||
|
||||
def _validate_minimum_duration(duration: float) -> None:
|
||||
"""Ensures video is at least 5 seconds long."""
|
||||
if duration < 5:
|
||||
raise ValueError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
|
||||
"""Trims video to 5 seconds if longer."""
|
||||
if duration > 5:
|
||||
return trim_video(video, 5)
|
||||
return video
|
||||
|
||||
|
||||
def parse_width_height_from_res(resolution: str):
|
||||
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
||||
res_map = {
|
||||
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
|
||||
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
|
||||
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
|
||||
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
|
||||
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
|
||||
# "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
|
||||
}
|
||||
return res_map.get(resolution, {"width": 1920, "height": 1080})
|
||||
|
||||
|
||||
def parse_control_parameter(value):
|
||||
control_map = {
|
||||
"Motion Transfer": "motion_control",
|
||||
"Canny": "canny_control",
|
||||
"Pose Transfer": "pose_control",
|
||||
"Depth": "depth_control",
|
||||
}
|
||||
return control_map.get(value, control_map["Motion Transfer"])
|
||||
|
||||
|
||||
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
status_extractor=lambda r: (r.status if r and r.status else None),
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=240,
|
||||
)
|
||||
|
||||
|
||||
class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MoonvalleyImg2VideoNode",
|
||||
display_name="Moonvalley Marey Image to Video",
|
||||
category="api node/video/Moonvalley Marey",
|
||||
description="Moonvalley Marey Image to Video Node",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The reference image used to generate the video",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
tooltip="Negative prompt text",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"16:9 (1920 x 1080)",
|
||||
"9:16 (1080 x 1920)",
|
||||
"1:1 (1152 x 1152)",
|
||||
"4:3 (1536 x 1152)",
|
||||
"3:4 (1152 x 1536)",
|
||||
# "21:9 (2560 x 1080)",
|
||||
],
|
||||
default="16:9 (1920 x 1080)",
|
||||
tooltip="Resolution of the output video",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"prompt_adherence",
|
||||
default=4.5,
|
||||
min=1.0,
|
||||
max=20.0,
|
||||
step=1.0,
|
||||
tooltip="Guidance scale for generation control",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=80,
|
||||
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Number of denoising steps",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(),
|
||||
expr="""{"type":"usd","usd": 1.5}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
resolution: str,
|
||||
prompt_adherence: float,
|
||||
seed: int,
|
||||
steps: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = parse_width_height_from_res(resolution)
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
seed=seed,
|
||||
guidance_scale=prompt_adherence,
|
||||
width=width_height["width"],
|
||||
height=width_height["height"],
|
||||
use_negative_prompts=True,
|
||||
)
|
||||
|
||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||
mime_type = "image/png"
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
),
|
||||
)
|
||||
validate_task_creation_response(task_creation_response)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return IO.NodeOutput(video)
|
||||
|
||||
|
||||
class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MoonvalleyVideo2VideoNode",
|
||||
display_name="Moonvalley Marey Video to Video",
|
||||
category="api node/video/Moonvalley Marey",
|
||||
description="",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Describes the video to generate",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
tooltip="Negative prompt text",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed value",
|
||||
control_after_generate=False,
|
||||
),
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
|
||||
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"control_type",
|
||||
options=["Motion Transfer", "Pose Transfer"],
|
||||
default="Motion Transfer",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"motion_intensity",
|
||||
default=100,
|
||||
min=0,
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Only used if control_type is 'Motion Transfer'",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=60,
|
||||
min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24)
|
||||
max=100,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Number of inference steps",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(),
|
||||
expr="""{"type":"usd","usd": 2.25}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
video: Input.Video | None = None,
|
||||
control_type: str = "Motion Transfer",
|
||||
motion_intensity: int | None = 100,
|
||||
steps=60,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
video_url = await upload_video_to_comfyapi(cls, validated_video)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
|
||||
# Only include motion_intensity for Motion Transfer
|
||||
control_params = {}
|
||||
if control_type == "Motion Transfer" and motion_intensity is not None:
|
||||
control_params["motion_intensity"] = motion_intensity
|
||||
|
||||
inference_params = MoonvalleyVideoToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
seed=seed,
|
||||
control_params=control_params,
|
||||
steps=steps,
|
||||
guidance_scale=prompt_adherence,
|
||||
)
|
||||
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyVideoToVideoRequest(
|
||||
control_type=parse_control_parameter(control_type),
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
),
|
||||
)
|
||||
validate_task_creation_response(task_creation_response)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||
|
||||
|
||||
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MoonvalleyTxt2VideoNode",
|
||||
display_name="Moonvalley Marey Text to Video",
|
||||
category="api node/video/Moonvalley Marey",
|
||||
description="",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||
tooltip="Negative prompt text",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"16:9 (1920 x 1080)",
|
||||
"9:16 (1080 x 1920)",
|
||||
"1:1 (1152 x 1152)",
|
||||
"4:3 (1536 x 1152)",
|
||||
"3:4 (1152 x 1536)",
|
||||
"21:9 (2560 x 1080)",
|
||||
],
|
||||
default="16:9 (1920 x 1080)",
|
||||
tooltip="Resolution of the output video",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"prompt_adherence",
|
||||
default=4.0,
|
||||
min=1.0,
|
||||
max=20.0,
|
||||
step=1.0,
|
||||
tooltip="Guidance scale for generation control",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=9,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Random seed value",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=80,
|
||||
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Inference steps",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(),
|
||||
expr="""{"type":"usd","usd": 1.5}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
resolution: str,
|
||||
prompt_adherence: float,
|
||||
seed: int,
|
||||
steps: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = parse_width_height_from_res(resolution)
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
seed=seed,
|
||||
guidance_scale=prompt_adherence,
|
||||
num_frames=128,
|
||||
width=width_height["width"],
|
||||
height=width_height["height"],
|
||||
)
|
||||
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
|
||||
)
|
||||
validate_task_creation_response(task_creation_response)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||
|
||||
|
||||
class MoonvalleyExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
MoonvalleyImg2VideoNode,
|
||||
MoonvalleyTxt2VideoNode,
|
||||
MoonvalleyVideo2VideoNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MoonvalleyExtension:
|
||||
return MoonvalleyExtension()
|
||||
@ -39,16 +39,18 @@ STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
|
||||
|
||||
|
||||
class SupportedOpenAIModel(str, Enum):
|
||||
o4_mini = "o4-mini"
|
||||
o1 = "o1"
|
||||
o3 = "o3"
|
||||
o1_pro = "o1-pro"
|
||||
gpt_4_1 = "gpt-4.1"
|
||||
gpt_4_1_mini = "gpt-4.1-mini"
|
||||
gpt_4_1_nano = "gpt-4.1-nano"
|
||||
gpt_5_5_pro = "gpt-5.5-pro"
|
||||
gpt_5_5 = "gpt-5.5"
|
||||
gpt_5 = "gpt-5"
|
||||
gpt_5_mini = "gpt-5-mini"
|
||||
gpt_5_nano = "gpt-5-nano"
|
||||
gpt_4_1 = "gpt-4.1"
|
||||
gpt_4_1_mini = "gpt-4.1-mini"
|
||||
gpt_4_1_nano = "gpt-4.1-nano"
|
||||
o4_mini = "o4-mini"
|
||||
o3 = "o3"
|
||||
o1_pro = "o1-pro"
|
||||
o1 = "o1"
|
||||
|
||||
|
||||
async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
|
||||
@ -415,8 +417,9 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
"1152x2048",
|
||||
"3840x2160",
|
||||
"2160x3840",
|
||||
"Custom",
|
||||
],
|
||||
tooltip="Image size",
|
||||
tooltip="Image size. Select 'Custom' to use the custom width and height (GPT Image 2 only).",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -445,6 +448,24 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
default="gpt-image-2",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"custom_width",
|
||||
default=1024,
|
||||
min=1024,
|
||||
max=3840,
|
||||
step=16,
|
||||
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16 (GPT Image 2 only).",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"custom_height",
|
||||
default=1024,
|
||||
min=1024,
|
||||
max=3840,
|
||||
step=16,
|
||||
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16 (GPT Image 2 only).",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -471,9 +492,9 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
"high": [0.133, 0.22]
|
||||
},
|
||||
"gpt-image-2": {
|
||||
"low": [0.0048, 0.012],
|
||||
"medium": [0.041, 0.112],
|
||||
"high": [0.165, 0.43]
|
||||
"low": [0.0048, 0.019],
|
||||
"medium": [0.041, 0.168],
|
||||
"high": [0.165, 0.67]
|
||||
}
|
||||
};
|
||||
$range := $lookup($lookup($ranges, widgets.model), widgets.quality);
|
||||
@ -503,6 +524,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
mask: Input.Image | None = None,
|
||||
n: int = 1,
|
||||
size: str = "1024x1024",
|
||||
custom_width: int = 1024,
|
||||
custom_height: int = 1024,
|
||||
model: str = "gpt-image-1",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
@ -510,7 +533,25 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
if mask is not None and image is None:
|
||||
raise ValueError("Cannot use a mask without an input image")
|
||||
|
||||
if model in ("gpt-image-1", "gpt-image-1.5"):
|
||||
if size == "Custom":
|
||||
if model != "gpt-image-2":
|
||||
raise ValueError("Custom resolution is only supported by GPT Image 2 model")
|
||||
if custom_width % 16 != 0 or custom_height % 16 != 0:
|
||||
raise ValueError(f"Custom width and height must be multiples of 16, got {custom_width}x{custom_height}")
|
||||
if max(custom_width, custom_height) > 3840:
|
||||
raise ValueError(f"Custom resolution max edge must be <= 3840, got {custom_width}x{custom_height}")
|
||||
ratio = max(custom_width, custom_height) / min(custom_width, custom_height)
|
||||
if ratio > 3:
|
||||
raise ValueError(
|
||||
f"Custom resolution aspect ratio must not exceed 3:1, got {custom_width}x{custom_height}"
|
||||
)
|
||||
total_pixels = custom_width * custom_height
|
||||
if not 655_360 <= total_pixels <= 8_294_400:
|
||||
raise ValueError(
|
||||
f"Custom resolution total pixels must be between 655,360 and 8,294,400, got {total_pixels}"
|
||||
)
|
||||
size = f"{custom_width}x{custom_height}"
|
||||
elif model in ("gpt-image-1", "gpt-image-1.5"):
|
||||
if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"):
|
||||
raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model")
|
||||
|
||||
@ -700,6 +741,16 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
"usd": [0.002, 0.008],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-5.5-pro") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.03, 0.18],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-5.5") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.005, 0.03],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-5-nano") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00005, 0.0004],
|
||||
|
||||
@ -33,7 +33,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIVideoSora2",
|
||||
display_name="OpenAI Sora - Video (Deprecated)",
|
||||
display_name="OpenAI Sora - Video (DEPRECATED)",
|
||||
category="api node/video/Sora",
|
||||
description=(
|
||||
"OpenAI video and audio generation.\n\n"
|
||||
|
||||
@ -36,11 +36,15 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
UPSCALER_MODELS_MAP = {
|
||||
"Astra 2": "ast-2",
|
||||
"Starlight (Astra) Fast": "slf-1",
|
||||
"Starlight (Astra) Creative": "slc-1",
|
||||
"Starlight Precise 2.5": "slp-2.5",
|
||||
}
|
||||
|
||||
AST2_MAX_FRAMES = 9000
|
||||
AST2_MAX_FRAMES_WITH_PROMPT = 450
|
||||
|
||||
|
||||
class TopazImageEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhance",
|
||||
display_name="Topaz Video Enhance",
|
||||
display_name="Topaz Video Enhance (Legacy)",
|
||||
category="api node/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Boolean.Input("upscaler_enabled", default=True),
|
||||
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
||||
IO.Combo.Input(
|
||||
"upscaler_model",
|
||||
options=[
|
||||
"Starlight (Astra) Fast",
|
||||
"Starlight (Astra) Creative",
|
||||
"Starlight Precise 2.5",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||
IO.Combo.Input(
|
||||
"upscaler_creativity",
|
||||
@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -453,7 +465,350 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=320,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
|
||||
|
||||
|
||||
class TopazVideoEnhanceV2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhanceV2",
|
||||
display_name="Topaz Video Enhance",
|
||||
category="api node/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.DynamicCombo.Input(
|
||||
"upscaler_model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Astra 2",
|
||||
[
|
||||
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||
IO.Float.Input(
|
||||
"creativity",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Creative strength of the upscale.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Optional descriptive (not instructive) scene prompt."
|
||||
f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"sharp",
|
||||
default=0.5,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Pre-enhance sharpness: "
|
||||
"0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"realism",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Pulls output toward photographic realism."
|
||||
"Leave at 0 for the model default.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Starlight (Astra) Fast",
|
||||
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Starlight (Astra) Creative",
|
||||
[
|
||||
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
|
||||
IO.Combo.Input(
|
||||
"creativity",
|
||||
options=["low", "middle", "high"],
|
||||
default="low",
|
||||
tooltip="Creative strength of the upscale.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Starlight Precise 2.5",
|
||||
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
|
||||
),
|
||||
IO.DynamicCombo.Option("Disabled", []),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"interpolation_model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Disabled", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"apo-8",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"interpolation_frame_rate",
|
||||
default=60,
|
||||
min=15,
|
||||
max=240,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Output frame rate.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"interpolation_slowmo",
|
||||
default=1,
|
||||
min=1,
|
||||
max=16,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Slow-motion factor applied to the input video. "
|
||||
"For example, 2 makes the output twice as slow and doubles the duration.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"interpolation_duplicate",
|
||||
default=False,
|
||||
tooltip="Analyze the input for duplicate frames and remove them.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"interpolation_duplicate_threshold",
|
||||
default=0.01,
|
||||
min=0.001,
|
||||
max=0.1,
|
||||
step=0.001,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Detection sensitivity for duplicate frames.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"dynamic_compression_level",
|
||||
options=["Low", "Mid", "High"],
|
||||
default="Low",
|
||||
tooltip="CQP level.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=[
|
||||
"upscaler_model",
|
||||
"upscaler_model.upscaler_resolution",
|
||||
"interpolation_model",
|
||||
]),
|
||||
expr="""
|
||||
(
|
||||
$model := $lookup(widgets, "upscaler_model");
|
||||
$res := $lookup(widgets, "upscaler_model.upscaler_resolution");
|
||||
$interp := $lookup(widgets, "interpolation_model");
|
||||
$is4k := $contains($res, "4k");
|
||||
$hasInterp := $interp != "disabled";
|
||||
$rates := {
|
||||
"starlight (astra) fast": {"hd": 0.43, "uhd": 0.85},
|
||||
"starlight precise 2.5": {"hd": 0.70, "uhd": 1.54},
|
||||
"astra 2": {"hd": 1.72, "uhd": 2.85},
|
||||
"starlight (astra) creative": {"hd": 2.25, "uhd": 3.99}
|
||||
};
|
||||
$surcharge := $is4k ? 0.28 : 0.14;
|
||||
$entry := $lookup($rates, $model);
|
||||
$base := $is4k ? $entry.uhd : $entry.hd;
|
||||
$hi := $base + ($hasInterp ? $surcharge : 0);
|
||||
$model = "disabled"
|
||||
? {"type":"text","text":"Interpolation only"}
|
||||
: ($hasInterp
|
||||
? {"type":"text","text":"~" & $string($base) & "–" & $string($hi) & " credits/src frame"}
|
||||
: {"type":"text","text":"~" & $string($base) & " credits/src frame"})
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
upscaler_model: dict,
|
||||
interpolation_model: dict,
|
||||
dynamic_compression_level: str = "Low",
|
||||
) -> IO.NodeOutput:
|
||||
upscaler_choice = upscaler_model["upscaler_model"]
|
||||
interpolation_choice = interpolation_model["interpolation_model"]
|
||||
if upscaler_choice == "Disabled" and interpolation_choice == "Disabled":
|
||||
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
|
||||
validate_container_format_is_mp4(video)
|
||||
src_width, src_height = video.get_dimensions()
|
||||
src_frame_rate = int(video.get_frame_rate())
|
||||
duration_sec = video.get_duration()
|
||||
src_video_stream = video.get_stream_source()
|
||||
target_width = src_width
|
||||
target_height = src_height
|
||||
target_frame_rate = src_frame_rate
|
||||
filters = []
|
||||
if upscaler_choice != "Disabled":
|
||||
if "1080p" in upscaler_model["upscaler_resolution"]:
|
||||
target_pixel_p = 1080
|
||||
max_long_side = 1920
|
||||
else:
|
||||
target_pixel_p = 2160
|
||||
max_long_side = 3840
|
||||
ar = src_width / src_height
|
||||
if src_width >= src_height:
|
||||
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
|
||||
target_height = target_pixel_p
|
||||
target_width = int(target_height * ar)
|
||||
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
|
||||
if target_width > max_long_side:
|
||||
target_width = max_long_side
|
||||
target_height = int(target_width / ar)
|
||||
else:
|
||||
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
|
||||
target_width = target_pixel_p
|
||||
target_height = int(target_width / ar)
|
||||
# Check if height exceeds standard bounds
|
||||
if target_height > max_long_side:
|
||||
target_height = max_long_side
|
||||
target_width = int(target_height * ar)
|
||||
if target_width % 2 != 0:
|
||||
target_width += 1
|
||||
if target_height % 2 != 0:
|
||||
target_height += 1
|
||||
model_id = UPSCALER_MODELS_MAP[upscaler_choice]
|
||||
if model_id == "slc-1":
|
||||
filters.append(
|
||||
VideoEnhancementFilter(
|
||||
model=model_id,
|
||||
creativity=upscaler_model["creativity"],
|
||||
isOptimizedMode=True,
|
||||
)
|
||||
)
|
||||
elif model_id == "ast-2":
|
||||
n_frames = video.get_frame_count()
|
||||
ast2_prompt = (upscaler_model["prompt"] or "").strip()
|
||||
if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT:
|
||||
raise ValueError(
|
||||
f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames "
|
||||
f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip."
|
||||
)
|
||||
if n_frames > AST2_MAX_FRAMES:
|
||||
raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.")
|
||||
realism = upscaler_model["realism"]
|
||||
filters.append(
|
||||
VideoEnhancementFilter(
|
||||
model=model_id,
|
||||
creativity=upscaler_model["creativity"],
|
||||
prompt=(ast2_prompt or None),
|
||||
sharp=upscaler_model["sharp"],
|
||||
realism=(realism if realism > 0 else None),
|
||||
)
|
||||
)
|
||||
else:
|
||||
filters.append(VideoEnhancementFilter(model=model_id))
|
||||
if interpolation_choice != "Disabled":
|
||||
target_frame_rate = interpolation_model["interpolation_frame_rate"]
|
||||
filters.append(
|
||||
VideoFrameInterpolationFilter(
|
||||
model=interpolation_choice,
|
||||
slowmo=interpolation_model["interpolation_slowmo"],
|
||||
fps=interpolation_model["interpolation_frame_rate"],
|
||||
duplicate=interpolation_model["interpolation_duplicate"],
|
||||
duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"],
|
||||
),
|
||||
)
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
|
||||
response_model=CreateVideoResponse,
|
||||
data=CreateVideoRequest(
|
||||
source=CreateVideoRequestSource(
|
||||
container="mp4",
|
||||
size=get_fs_object_size(src_video_stream),
|
||||
duration=int(duration_sec),
|
||||
frameCount=video.get_frame_count(),
|
||||
frameRate=src_frame_rate,
|
||||
resolution=Resolution(width=src_width, height=src_height),
|
||||
),
|
||||
filters=filters,
|
||||
output=OutputInformationVideo(
|
||||
resolution=Resolution(width=target_width, height=target_height),
|
||||
frameRate=target_frame_rate,
|
||||
audioCodec="AAC",
|
||||
audioTransfer="Copy",
|
||||
dynamicCompressionLevel=dynamic_compression_level,
|
||||
),
|
||||
),
|
||||
wait_label="Creating task",
|
||||
final_label_on_success="Task created",
|
||||
)
|
||||
upload_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=VideoAcceptResponse,
|
||||
wait_label="Preparing upload",
|
||||
final_label_on_success="Upload started",
|
||||
)
|
||||
if len(upload_res.urls) > 1:
|
||||
raise NotImplementedError(
|
||||
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
|
||||
)
|
||||
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
|
||||
if isinstance(src_video_stream, BytesIO):
|
||||
src_video_stream.seek(0)
|
||||
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
else:
|
||||
with builtins.open(src_video_stream, "rb") as video_file:
|
||||
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=VideoCompleteUploadResponse,
|
||||
data=VideoCompleteUploadRequest(
|
||||
uploadResults=[
|
||||
VideoCompleteUploadRequestPart(
|
||||
partNum=1,
|
||||
eTag=upload_etag,
|
||||
),
|
||||
],
|
||||
),
|
||||
wait_label="Finalizing upload",
|
||||
final_label_on_success="Upload completed",
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
|
||||
response_model=VideoStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
|
||||
poll_interval=10.0,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
|
||||
|
||||
@ -464,6 +819,7 @@ class TopazExtension(ComfyExtension):
|
||||
return [
|
||||
TopazImageEnhance,
|
||||
TopazVideoEnhance,
|
||||
TopazVideoEnhanceV2,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ async def execute_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
vidu_endpoint: str,
|
||||
payload: TaskCreationRequest | TaskExtendCreationRequest | TaskMultiFrameCreationRequest,
|
||||
max_poll_attempts: int = 320,
|
||||
max_poll_attempts: int = 480,
|
||||
) -> list[TaskResult]:
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
@ -1097,7 +1097,6 @@ class ViduExtendVideoNode(IO.ComfyNode):
|
||||
video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading video"),
|
||||
images=[image_url] if image_url else None,
|
||||
),
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
@ -818,7 +818,6 @@ class WanReferenceVideoApi(IO.ComfyNode):
|
||||
response_model=VideoTaskStatusResponse,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
poll_interval=6,
|
||||
max_poll_attempts=280,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||
|
||||
|
||||
@ -84,7 +84,6 @@ class WavespeedFlashVSRNode(IO.ComfyNode):
|
||||
response_model=TaskResultResponse,
|
||||
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
if final_response.code != 200:
|
||||
raise ValueError(
|
||||
@ -156,7 +155,6 @@ class WavespeedImageUpscaleNode(IO.ComfyNode):
|
||||
response_model=TaskResultResponse,
|
||||
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
if final_response.code != 200:
|
||||
raise ValueError(
|
||||
|
||||
@ -19,6 +19,8 @@ from comfy import utils
|
||||
from comfy_api.latest import IO
|
||||
from server import PromptServer
|
||||
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
@ -148,7 +150,7 @@ async def poll_op(
|
||||
queued_statuses: list[str | int] | None = None,
|
||||
data: BaseModel | None = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 160,
|
||||
max_poll_attempts: int = 480,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 10,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
@ -254,7 +256,7 @@ async def poll_op_raw(
|
||||
queued_statuses: list[str | int] | None = None,
|
||||
data: dict[str, Any] | BaseModel | None = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 160,
|
||||
max_poll_attempts: int = 480,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 10,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
@ -624,6 +626,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
payload_headers["Comfy-Env"] = get_deploy_environment()
|
||||
if cfg.endpoint.headers:
|
||||
payload_headers.update(cfg.endpoint.headers)
|
||||
|
||||
|
||||
@ -199,6 +199,9 @@ class FILMNet(nn.Module):
|
||||
def get_dtype(self):
|
||||
return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
|
||||
|
||||
def memory_used_forward(self, shape, dtype):
|
||||
return 1700 * shape[1] * shape[2] * dtype.itemsize
|
||||
|
||||
def _build_warp_grids(self, H, W, device):
|
||||
"""Pre-compute warp grids for all pyramid levels."""
|
||||
if (H, W) in self._warp_grids:
|
||||
|
||||
@ -74,6 +74,9 @@ class IFNet(nn.Module):
|
||||
def get_dtype(self):
|
||||
return self.encode.cnn0.weight.dtype
|
||||
|
||||
def memory_used_forward(self, shape, dtype):
|
||||
return 300 * shape[1] * shape[2] * dtype.itemsize
|
||||
|
||||
def _build_warp_grids(self, H, W, device):
|
||||
if (H, W) in self._warp_grids:
|
||||
return
|
||||
|
||||
@ -42,7 +42,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode):
|
||||
IO.Int.Input("bpm", default=120, min=10, max=300),
|
||||
IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
|
||||
IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
|
||||
IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
|
||||
IO.Combo.Input("language", options=['ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id', 'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no', 'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh', 'unknown'], default='en'),
|
||||
IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
|
||||
IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
|
||||
IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
|
||||
|
||||
84
comfy_extras/nodes_ar_video.py
Normal file
84
comfy_extras/nodes_ar_video.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""
|
||||
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
||||
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
|
||||
"""
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class EmptyARVideoLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyARVideoLatent",
|
||||
category="latent/video",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=1024, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="LATENT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size) -> io.NodeOutput:
|
||||
lat_t = ((length - 1) // 4) + 1
|
||||
latent = torch.zeros(
|
||||
[batch_size, 16, lat_t, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
return io.NodeOutput({"samples": latent})
|
||||
|
||||
|
||||
class SamplerARVideo(io.ComfyNode):
|
||||
"""Sampler for autoregressive video models (Causal Forcing, Self-Forcing).
|
||||
|
||||
All AR-loop parameters are owned by this node so they live in the workflow.
|
||||
Add new widgets here as the AR sampler grows new options.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerARVideo",
|
||||
display_name="Sampler AR Video",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Int.Input(
|
||||
"num_frame_per_block",
|
||||
default=1, min=1, max=64,
|
||||
tooltip="Frames per autoregressive block. 1 = framewise, "
|
||||
"3 = chunkwise. Must match the checkpoint's training mode.",
|
||||
),
|
||||
],
|
||||
outputs=[io.Sampler.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, num_frame_per_block) -> io.NodeOutput:
|
||||
extra_options = {
|
||||
"num_frame_per_block": num_frame_per_block,
|
||||
}
|
||||
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
|
||||
|
||||
|
||||
class ARVideoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EmptyARVideoLatent,
|
||||
SamplerARVideo,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ARVideoExtension:
|
||||
return ARVideoExtension()
|
||||
@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||
batch_size = min(len(image), len(alpha))
|
||||
out_images = []
|
||||
|
||||
batch_size = max(len(image), len(alpha))
|
||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
||||
for i in range(batch_size):
|
||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||
|
||||
return io.NodeOutput(torch.stack(out_images))
|
||||
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
|
||||
image = comfy.utils.repeat_to_batch_size(image, batch_size)
|
||||
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
|
||||
|
||||
|
||||
class CompositingExtension(ComfyExtension):
|
||||
|
||||
@ -29,6 +29,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
|
||||
model = model.clone()
|
||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||
@ -50,7 +51,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
dim=dim,
|
||||
freenoise=freenoise,
|
||||
cond_retain_index_list=cond_retain_index_list,
|
||||
split_conds_to_windows=split_conds_to_windows
|
||||
split_conds_to_windows=split_conds_to_windows,
|
||||
causal_window_fix=causal_window_fix,
|
||||
)
|
||||
# make memory usage calculation only take into account the context window latents
|
||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||
|
||||
@ -37,7 +37,7 @@ class FrameInterpolationModelLoader(io.ComfyNode):
|
||||
model = cls._detect_and_load(sd)
|
||||
dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
|
||||
model.eval().to(dtype)
|
||||
patcher = comfy.model_patcher.ModelPatcher(
|
||||
patcher = comfy.model_patcher.CoreModelPatcher(
|
||||
model,
|
||||
load_device=model_management.get_torch_device(),
|
||||
offload_device=model_management.unet_offload_device(),
|
||||
@ -78,7 +78,7 @@ class FrameInterpolate(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="FrameInterpolate",
|
||||
display_name="Frame Interpolate",
|
||||
category="image/video",
|
||||
category="video",
|
||||
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
|
||||
inputs=[
|
||||
FrameInterpolationModel.Input("interp_model"),
|
||||
@ -98,16 +98,13 @@ class FrameInterpolate(io.ComfyNode):
|
||||
if num_frames < 2 or multiplier < 2:
|
||||
return io.NodeOutput(images)
|
||||
|
||||
model_management.load_model_gpu(interp_model)
|
||||
device = interp_model.load_device
|
||||
dtype = interp_model.model_dtype()
|
||||
inference_model = interp_model.model
|
||||
|
||||
# Free VRAM for inference activations (model weights + ~20x a single frame's worth)
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
activation_mem = H * W * 3 * images.element_size() * 20
|
||||
model_management.free_memory(activation_mem, device)
|
||||
activation_mem = inference_model.memory_used_forward(images.shape, dtype)
|
||||
model_management.load_models_gpu([interp_model], memory_required=activation_mem)
|
||||
align = getattr(inference_model, "pad_align", 1)
|
||||
H, W = images.shape[1], images.shape[2]
|
||||
|
||||
# Prepare a single padded frame on device for determining output dimensions
|
||||
def prepare_frame(idx):
|
||||
|
||||
@ -11,7 +11,7 @@ class ImageCompare(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCompare",
|
||||
display_name="Image Compare",
|
||||
display_name="Compare Images",
|
||||
description="Compares two images side by side with a slider.",
|
||||
category="image",
|
||||
essentials_category="Image Tools",
|
||||
|
||||
@ -24,7 +24,7 @@ class ImageCrop(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ImageCrop",
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop (Deprecated)",
|
||||
display_name="Crop Image (DEPRECATED)",
|
||||
category="image/transform",
|
||||
is_deprecated=True,
|
||||
essentials_category="Image Tools",
|
||||
@ -56,7 +56,7 @@ class ImageCropV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ImageCropV2",
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
display_name="Crop Image",
|
||||
category="image/transform",
|
||||
essentials_category="Image Tools",
|
||||
has_intermediate_output=True,
|
||||
@ -109,6 +109,7 @@ class RepeatImageBatch(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RepeatImageBatch",
|
||||
search_aliases=["duplicate image", "clone image"],
|
||||
display_name="Repeat Image Batch",
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -131,6 +132,7 @@ class ImageFromBatch(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ImageFromBatch",
|
||||
search_aliases=["select image", "pick from batch", "extract image"],
|
||||
display_name="Get Image from Batch",
|
||||
category="image/batch",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -157,7 +159,8 @@ class ImageAddNoise(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ImageAddNoise",
|
||||
search_aliases=["film grain"],
|
||||
category="image",
|
||||
display_name="Add Noise to Image",
|
||||
category="image/postprocessing",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input(
|
||||
@ -259,7 +262,7 @@ class ImageStitch(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ImageStitch",
|
||||
search_aliases=["combine images", "join images", "concatenate images", "side by side"],
|
||||
display_name="Image Stitch",
|
||||
display_name="Stitch Images",
|
||||
description="Stitches image2 to image1 in the specified direction.\n"
|
||||
"If image2 is not provided, returns image1 unchanged.\n"
|
||||
"Optional spacing can be added between images.",
|
||||
@ -434,6 +437,7 @@ class ResizeAndPadImage(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ResizeAndPadImage",
|
||||
search_aliases=["fit to size"],
|
||||
display_name="Resize And Pad Image",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -485,6 +489,7 @@ class SaveSVGNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SaveSVGNode",
|
||||
search_aliases=["export vector", "save vector graphics"],
|
||||
display_name="Save SVG",
|
||||
description="Save SVG files on disk.",
|
||||
category="image/save",
|
||||
inputs=[
|
||||
@ -591,7 +596,7 @@ class ImageRotate(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageRotate",
|
||||
display_name="Image Rotate",
|
||||
display_name="Rotate Image",
|
||||
search_aliases=["turn", "flip orientation"],
|
||||
category="image/transform",
|
||||
essentials_category="Image Tools",
|
||||
@ -624,6 +629,7 @@ class ImageFlip(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ImageFlip",
|
||||
search_aliases=["mirror", "reflect"],
|
||||
display_name="Flip Image",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -650,6 +656,7 @@ class ImageScaleToMaxDimension(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageScaleToMaxDimension",
|
||||
display_name="Scale Image to Max Dimension",
|
||||
category="image/upscaling",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -709,7 +716,7 @@ class SplitImageToTileList(IO.ComfyNode):
|
||||
def get_grid_coords(width, height, tile_width, tile_height, overlap):
|
||||
coords = []
|
||||
stride_x = round(max(tile_width * 0.25, tile_width - overlap))
|
||||
stride_y = round(max(tile_width * 0.25, tile_height - overlap))
|
||||
stride_y = round(max(tile_height * 0.25, tile_height - overlap))
|
||||
|
||||
y = 0
|
||||
while y < height:
|
||||
|
||||
@ -147,7 +147,6 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||
|
||||
z_channels = audio_vae.latent_channels
|
||||
audio_freq = audio_vae.first_stage_model.latent_frequency_bins
|
||||
sampling_rate = int(audio_vae.first_stage_model.sample_rate)
|
||||
|
||||
num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
|
||||
|
||||
@ -159,7 +158,6 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||
return io.NodeOutput(
|
||||
{
|
||||
"samples": audio_latents,
|
||||
"sample_rate": sampling_rate,
|
||||
"type": "audio",
|
||||
}
|
||||
)
|
||||
|
||||
@ -80,7 +80,8 @@ class ImageCompositeMasked(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCompositeMasked",
|
||||
search_aliases=["paste image", "overlay", "layer"],
|
||||
search_aliases=["overlay", "layer", "paste image", "images composition"],
|
||||
display_name="Image Composite Masked",
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("destination"),
|
||||
@ -201,6 +202,7 @@ class InvertMask(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="InvertMask",
|
||||
search_aliases=["reverse mask", "flip mask"],
|
||||
display_name="Invert Mask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
@ -222,6 +224,7 @@ class CropMask(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="CropMask",
|
||||
search_aliases=["cut mask", "extract mask region", "mask slice"],
|
||||
display_name="Crop Mask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
@ -247,7 +250,8 @@ class MaskComposite(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MaskComposite",
|
||||
search_aliases=["combine masks", "blend masks", "layer masks"],
|
||||
search_aliases=["combine masks", "blend masks", "layer masks", "masks composition"],
|
||||
display_name="Combine Masks",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("destination"),
|
||||
@ -298,6 +302,7 @@ class FeatherMask(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="FeatherMask",
|
||||
search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"],
|
||||
display_name="Feather Mask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
|
||||
@ -59,7 +59,8 @@ class ImageRGBToYUV(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ImageRGBToYUV",
|
||||
search_aliases=["color space conversion"],
|
||||
category="image/batch",
|
||||
display_name="Image RGB to YUV",
|
||||
category="image/color",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
],
|
||||
@ -81,7 +82,8 @@ class ImageYUVToRGB(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ImageYUVToRGB",
|
||||
search_aliases=["color space conversion"],
|
||||
category="image/batch",
|
||||
display_name="Image YUV to RGB",
|
||||
category="image/color",
|
||||
inputs=[
|
||||
io.Image.Input("Y"),
|
||||
io.Image.Input("U"),
|
||||
|
||||
@ -20,7 +20,8 @@ class Blend(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageBlend",
|
||||
display_name="Image Blend",
|
||||
search_aliases=["mix images"],
|
||||
display_name="Blend Images",
|
||||
category="image/postprocessing",
|
||||
essentials_category="Image Tools",
|
||||
inputs=[
|
||||
@ -224,6 +225,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ImageScaleToTotalPixels",
|
||||
display_name="Scale Image to Total Pixels",
|
||||
category="image/upscaling",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
@ -568,7 +570,7 @@ class BatchImagesNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="BatchImagesNode",
|
||||
display_name="Batch Images",
|
||||
category="image",
|
||||
category="image/batch",
|
||||
essentials_category="Image Tools",
|
||||
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
|
||||
inputs=[
|
||||
@ -666,12 +668,13 @@ class ColorTransfer(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ColorTransfer",
|
||||
display_name="Color Transfer",
|
||||
category="image/postprocessing",
|
||||
description="Match the colors of one image to another using various algorithms.",
|
||||
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
|
||||
inputs=[
|
||||
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
|
||||
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"),
|
||||
io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."),
|
||||
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
|
||||
io.DynamicCombo.Input("source_stats",
|
||||
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",
|
||||
|
||||
@ -9,7 +9,8 @@ class String(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveString",
|
||||
display_name="String",
|
||||
search_aliases=["text", "string", "text box", "prompt"],
|
||||
display_name="Text String",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.String.Input("value"),
|
||||
@ -27,7 +28,8 @@ class StringMultiline(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PrimitiveStringMultiline",
|
||||
display_name="String (Multiline)",
|
||||
search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"],
|
||||
display_name="Text String (Multiline)",
|
||||
category="utils/primitive",
|
||||
essentials_category="Basics",
|
||||
inputs=[
|
||||
@ -49,7 +51,7 @@ class Int(io.ComfyNode):
|
||||
display_name="Int",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
|
||||
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
|
||||
],
|
||||
outputs=[io.Int.Output()],
|
||||
)
|
||||
|
||||
@ -459,27 +459,23 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
total_images = image.shape[0]
|
||||
captured_feat = None
|
||||
|
||||
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
|
||||
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
|
||||
model_w = int(head.heatmap_size[0]) * 4 # 192 * 4 = 768
|
||||
model_h = int(head.heatmap_size[1]) * 4 # 256 * 4 = 1024
|
||||
|
||||
def _resize_to_model(imgs):
|
||||
"""Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
|
||||
"""Stretch BHWC images to (model_h, model_w), model expects no aspect preservation."""
|
||||
h, w = imgs.shape[-3], imgs.shape[-2]
|
||||
scale = min(model_h / h, model_w / w)
|
||||
sh, sw = int(round(h * scale)), int(round(w * scale))
|
||||
pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
|
||||
method = "area" if (model_h <= h and model_w <= w) else "bilinear"
|
||||
chw = imgs.permute(0, 3, 1, 2).float()
|
||||
scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
|
||||
padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
|
||||
padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
|
||||
return padded.permute(0, 2, 3, 1), scale, pt, pl
|
||||
scaled = comfy.utils.common_upscale(chw, model_w, model_h, upscale_method=method, crop="disabled")
|
||||
return scaled.permute(0, 2, 3, 1), model_w / w, model_h / h
|
||||
|
||||
def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
|
||||
def _remap_keypoints(kp, scale_x, scale_y, offset_x=0, offset_y=0):
|
||||
"""Remap keypoints from model space back to original image space."""
|
||||
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
|
||||
invalid = kp[..., 0] < 0
|
||||
kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
|
||||
kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
|
||||
kp[..., 0] = kp[..., 0] / scale_x + offset_x
|
||||
kp[..., 1] = kp[..., 1] / scale_y + offset_y
|
||||
kp[invalid] = -1
|
||||
return kp
|
||||
|
||||
@ -529,18 +525,18 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
continue
|
||||
|
||||
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
|
||||
crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
|
||||
crop_resized, sx, sy = _resize_to_model(crop)
|
||||
|
||||
latent_crop = vae.encode(crop_resized)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_crop)
|
||||
kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
|
||||
kp = _remap_keypoints(kp_batch[0], sx, sy, x1, y1)
|
||||
img_keypoints.append(kp)
|
||||
img_scores.append(sc_batch[0])
|
||||
else:
|
||||
img_resized, scale, pad_top, pad_left = _resize_to_model(img)
|
||||
img_resized, sx, sy = _resize_to_model(img)
|
||||
latent_img = vae.encode(img_resized)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_img)
|
||||
img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
|
||||
img_keypoints.append(_remap_keypoints(kp_batch[0], sx, sy))
|
||||
img_scores.append(sc_batch[0])
|
||||
|
||||
all_keypoints.append(img_keypoints)
|
||||
@ -549,12 +545,12 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
|
||||
else: # full-image mode, batched
|
||||
for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
|
||||
batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
|
||||
batch_resized, sx, sy = _resize_to_model(image[batch_start:batch_start + batch_size])
|
||||
latent_batch = vae.encode(batch_resized)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_batch)
|
||||
|
||||
for kp, sc in zip(kp_batch, sc_batch):
|
||||
all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
|
||||
all_keypoints.append([_remap_keypoints(kp, sx, sy)])
|
||||
all_scores.append([sc])
|
||||
|
||||
pbar.update(len(kp_batch))
|
||||
@ -727,13 +723,13 @@ class CropByBBoxes(io.ComfyNode):
|
||||
scale = min(output_width / crop_w, output_height / crop_h)
|
||||
scaled_w = int(round(crop_w * scale))
|
||||
scaled_h = int(round(crop_h * scale))
|
||||
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
|
||||
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="area", crop="disabled")
|
||||
pad_left = (output_width - scaled_w) // 2
|
||||
pad_top = (output_height - scaled_h) // 2
|
||||
resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
|
||||
resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
|
||||
else: # "stretch"
|
||||
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
|
||||
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="area", crop="disabled")
|
||||
crops.append(resized)
|
||||
|
||||
if not crops:
|
||||
|
||||
@ -10,9 +10,9 @@ class StringConcatenate(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringConcatenate",
|
||||
display_name="Text Concatenate",
|
||||
category="utils/string",
|
||||
search_aliases=["Concatenate", "text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"],
|
||||
search_aliases=["concatenate", "text concat", "join text", "merge text", "combine strings", "string concat", "append text", "combine text"],
|
||||
display_name="Concatenate Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string_a", multiline=True),
|
||||
io.String.Input("string_b", multiline=True),
|
||||
@ -33,9 +33,9 @@ class StringSubstring(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringSubstring",
|
||||
search_aliases=["Substring", "extract text", "text portion"],
|
||||
display_name="Text Substring",
|
||||
category="utils/string",
|
||||
search_aliases=["substring", "extract text", "text portion"],
|
||||
display_name="Substring",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Int.Input("start"),
|
||||
@ -58,7 +58,7 @@ class StringLength(io.ComfyNode):
|
||||
node_id="StringLength",
|
||||
search_aliases=["character count", "text size", "string length"],
|
||||
display_name="Text Length",
|
||||
category="utils/string",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
],
|
||||
@ -77,9 +77,9 @@ class CaseConverter(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CaseConverter",
|
||||
search_aliases=["Case Converter", "text case", "uppercase", "lowercase", "capitalize"],
|
||||
display_name="Text Case Converter",
|
||||
category="utils/string",
|
||||
search_aliases=["case converter", "text case", "uppercase", "lowercase", "capitalize"],
|
||||
display_name="Convert Text Case",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]),
|
||||
@ -110,9 +110,9 @@ class StringTrim(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringTrim",
|
||||
search_aliases=["Trim", "clean whitespace", "remove whitespace", "strip"],
|
||||
display_name="Text Trim",
|
||||
category="utils/string",
|
||||
search_aliases=["trim", "clean whitespace", "remove whitespace", "remove spaces","strip"],
|
||||
display_name="Trim Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.Combo.Input("mode", options=["Both", "Left", "Right"]),
|
||||
@ -141,9 +141,9 @@ class StringReplace(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringReplace",
|
||||
search_aliases=["Replace", "find and replace", "substitute", "swap text"],
|
||||
display_name="Text Replace",
|
||||
category="utils/string",
|
||||
search_aliases=["replace", "find and replace", "substitute", "swap text"],
|
||||
display_name="Replace Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("find", multiline=True),
|
||||
@ -164,9 +164,9 @@ class StringContains(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringContains",
|
||||
search_aliases=["Contains", "text includes", "string includes"],
|
||||
display_name="Text Contains",
|
||||
category="utils/string",
|
||||
search_aliases=["contains", "text includes", "string includes"],
|
||||
display_name="Contains Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("substring", multiline=True),
|
||||
@ -192,9 +192,9 @@ class StringCompare(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="StringCompare",
|
||||
search_aliases=["Compare", "text match", "string equals", "starts with", "ends with"],
|
||||
display_name="Text Compare",
|
||||
category="utils/string",
|
||||
search_aliases=["compare", "text match", "string equals", "starts with", "ends with"],
|
||||
display_name="Compare Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string_a", multiline=True),
|
||||
io.String.Input("string_b", multiline=True),
|
||||
@ -228,9 +228,9 @@ class RegexMatch(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexMatch",
|
||||
search_aliases=["Regex Match", "regex", "pattern match", "text contains", "string match"],
|
||||
display_name="Text Match",
|
||||
category="utils/string",
|
||||
search_aliases=["regex match", "regex", "pattern match", "text contains", "string match"],
|
||||
display_name="Match Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("regex_pattern", multiline=True),
|
||||
@ -269,9 +269,9 @@ class RegexExtract(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexExtract",
|
||||
search_aliases=["Regex Extract", "regex", "pattern extract", "text parser", "parse text"],
|
||||
display_name="Text Extract Substring",
|
||||
category="utils/string",
|
||||
search_aliases=["regex extract", "regex", "pattern extract", "text parser", "parse text"],
|
||||
display_name="Extract Text",
|
||||
category="text",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
io.String.Input("regex_pattern", multiline=True),
|
||||
@ -344,9 +344,9 @@ class RegexReplace(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RegexReplace",
|
||||
search_aliases=["Regex Replace", "regex", "pattern replace", "regex replace", "substitution"],
|
||||
display_name="Text Replace (Regex)",
|
||||
category="utils/string",
|
||||
search_aliases=["regex replace", "regex", "pattern replace", "substitution"],
|
||||
display_name="Replace Text (Regex)",
|
||||
category="text",
|
||||
description="Find and replace text using regex patterns.",
|
||||
inputs=[
|
||||
io.String.Input("string", multiline=True),
|
||||
@ -381,8 +381,8 @@ class JsonExtractString(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="JsonExtractString",
|
||||
display_name="Extract String from JSON",
|
||||
category="utils/string",
|
||||
display_name="Extract Text from JSON",
|
||||
category="text",
|
||||
search_aliases=["json", "extract json", "parse json", "json value", "read json"],
|
||||
inputs=[
|
||||
io.String.Input("json_string", multiline=True),
|
||||
|
||||
@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Image.Input("image", optional=True),
|
||||
io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
|
||||
io.Audio.Input("audio", optional=True),
|
||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
|
||||
@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
|
||||
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
|
||||
|
||||
# Get sampling parameters from dynamic combo
|
||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||
@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode):
|
||||
seed=seed
|
||||
)
|
||||
|
||||
generated_text = clip.decode(generated_ids, skip_special_tokens=True)
|
||||
generated_text = clip.decode(generated_ids)
|
||||
|
||||
return io.NodeOutput(generated_text)
|
||||
|
||||
|
||||
@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
|
||||
if image is None:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
else:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
|
||||
|
||||
|
||||
class TextgenExtension(ComfyExtension):
|
||||
|
||||
@ -17,7 +17,8 @@ class SaveWEBM(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SaveWEBM",
|
||||
search_aliases=["export webm"],
|
||||
category="image/video",
|
||||
display_name="Save WEBM",
|
||||
category="video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Image.Input("images"),
|
||||
@ -72,7 +73,7 @@ class SaveVideo(io.ComfyNode):
|
||||
node_id="SaveVideo",
|
||||
search_aliases=["export video"],
|
||||
display_name="Save Video",
|
||||
category="image/video",
|
||||
category="video",
|
||||
essentials_category="Basics",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
inputs=[
|
||||
@ -121,7 +122,7 @@ class CreateVideo(io.ComfyNode):
|
||||
node_id="CreateVideo",
|
||||
search_aliases=["images to video"],
|
||||
display_name="Create Video",
|
||||
category="image/video",
|
||||
category="video",
|
||||
description="Create a video from images.",
|
||||
inputs=[
|
||||
io.Image.Input("images", tooltip="The images to create a video from."),
|
||||
@ -146,7 +147,7 @@ class GetVideoComponents(io.ComfyNode):
|
||||
node_id="GetVideoComponents",
|
||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||
display_name="Get Video Components",
|
||||
category="image/video",
|
||||
category="video",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||
@ -174,7 +175,7 @@ class LoadVideo(io.ComfyNode):
|
||||
node_id="LoadVideo",
|
||||
search_aliases=["import video", "open video", "video file"],
|
||||
display_name="Load Video",
|
||||
category="image/video",
|
||||
category="video",
|
||||
essentials_category="Basics",
|
||||
inputs=[
|
||||
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
|
||||
@ -216,7 +217,7 @@ class VideoSlice(io.ComfyNode):
|
||||
"frame load cap",
|
||||
"start time",
|
||||
],
|
||||
category="image/video",
|
||||
category="video",
|
||||
essentials_category="Video Tools",
|
||||
inputs=[
|
||||
io.Video.Input("video"),
|
||||
|
||||
483
comfy_extras/nodes_void.py
Normal file
483
comfy_extras/nodes_void.py
Normal file
@ -0,0 +1,483 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
import comfy
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
import nodes
|
||||
from comfy.utils import model_trange as trange
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from torchvision.models.optical_flow import raft_large
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
from comfy_extras.void_noise_warp import RaftOpticalFlow, get_noise_from_video
|
||||
|
||||
OpticalFlow = io.Custom("OPTICAL_FLOW")
|
||||
|
||||
TEMPORAL_COMPRESSION = 4
|
||||
PATCH_SIZE_T = 2
|
||||
|
||||
|
||||
def _valid_void_length(length: int) -> int:
|
||||
"""Round ``length`` down to a value that produces an even latent_t.
|
||||
|
||||
VOID / CogVideoX-Fun-V1.5 uses patch_size_t=2, so the VAE-encoded latent
|
||||
must have an even temporal dimension. If latent_t is odd, the transformer
|
||||
pad_to_patch_size circular-wraps an extra latent frame onto the end; after
|
||||
the post-transformer crop the last real latent frame has been influenced
|
||||
by the wrapped phantom frame, producing visible jitter and "disappearing"
|
||||
subjects near the end of the decoded video. Rounding down fixes this.
|
||||
"""
|
||||
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
|
||||
if latent_t % PATCH_SIZE_T == 0:
|
||||
return length
|
||||
# Round latent_t down to the nearest multiple of PATCH_SIZE_T, then invert
|
||||
# the ((length - 1) // TEMPORAL_COMPRESSION) + 1 formula. Floor at 1 frame
|
||||
# so we never return a non-positive length.
|
||||
target_latent_t = max(PATCH_SIZE_T, (latent_t // PATCH_SIZE_T) * PATCH_SIZE_T)
|
||||
return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1
|
||||
|
||||
|
||||
class OpticalFlowLoader(io.ComfyNode):
|
||||
"""Load an optical flow model from ``models/optical_flow/``.
|
||||
|
||||
Only torchvision's RAFT-large format is recognized today (the model used
|
||||
by VOIDWarpedNoise). The checkpoint must be placed under
|
||||
``models/optical_flow/`` — ComfyUI never downloads optical-flow weights
|
||||
at runtime.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="OpticalFlowLoader",
|
||||
display_name="Load Optical Flow Model",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"model_name",
|
||||
options=folder_paths.get_filename_list("optical_flow"),
|
||||
tooltip=(
|
||||
"Optical flow model to load. Files must be placed in the "
|
||||
"'optical_flow' folder. Today only torchvision's "
|
||||
"raft_large.pth is supported."
|
||||
),
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
OpticalFlow.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
|
||||
model_path = folder_paths.get_full_path_or_raise("optical_flow", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
|
||||
has_raft_keys = (
|
||||
any(k.startswith("feature_encoder.") for k in sd)
|
||||
and any(k.startswith("context_encoder.") for k in sd)
|
||||
and any(k.startswith("update_block.") for k in sd)
|
||||
)
|
||||
if not has_raft_keys:
|
||||
raise ValueError(
|
||||
"Unrecognized optical flow model format: expected a torchvision "
|
||||
"RAFT-large state dict with 'feature_encoder.', 'context_encoder.' "
|
||||
"and 'update_block.' prefixes."
|
||||
)
|
||||
|
||||
model = raft_large(weights=None, progress=False)
|
||||
model.load_state_dict(sd)
|
||||
model.eval().to(torch.float32)
|
||||
|
||||
patcher = comfy.model_patcher.ModelPatcher(
|
||||
model,
|
||||
load_device=comfy.model_management.get_torch_device(),
|
||||
offload_device=comfy.model_management.unet_offload_device(),
|
||||
)
|
||||
return io.NodeOutput(patcher)
|
||||
|
||||
|
||||
class VOIDQuadmaskPreprocess(io.ComfyNode):
|
||||
"""Preprocess a quadmask video for VOID inpainting.
|
||||
|
||||
Quantizes mask values to four semantic levels, inverts, and normalizes:
|
||||
0 -> primary object to remove
|
||||
63 -> overlap of primary + affected
|
||||
127 -> affected region (interactions)
|
||||
255 -> background (keep)
|
||||
|
||||
After inversion and normalization, the output mask has values in [0, 1]
|
||||
with four discrete levels: 1.0 (remove), ~0.75, ~0.50, 0.0 (keep).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VOIDQuadmaskPreprocess",
|
||||
category="mask/video",
|
||||
inputs=[
|
||||
io.Mask.Input("mask"),
|
||||
io.Int.Input("dilate_width", default=0, min=0, max=50, step=1,
|
||||
tooltip="Dilation radius for the primary mask region (0 = no dilation)"),
|
||||
],
|
||||
outputs=[
|
||||
io.Mask.Output(display_name="quadmask"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, dilate_width=0) -> io.NodeOutput:
|
||||
m = mask.clone()
|
||||
|
||||
if m.max() <= 1.0:
|
||||
m = m * 255.0
|
||||
|
||||
if dilate_width > 0 and m.ndim >= 3:
|
||||
binary = (m < 128).float()
|
||||
kernel_size = dilate_width * 2 + 1
|
||||
if binary.ndim == 3:
|
||||
binary = binary.unsqueeze(1)
|
||||
dilated = torch.nn.functional.max_pool2d(
|
||||
binary, kernel_size=kernel_size, stride=1, padding=dilate_width
|
||||
)
|
||||
if dilated.ndim == 4:
|
||||
dilated = dilated.squeeze(1)
|
||||
m = torch.where(dilated > 0.5, torch.zeros_like(m), m)
|
||||
|
||||
m = torch.where(m <= 31, torch.zeros_like(m), m)
|
||||
m = torch.where((m > 31) & (m <= 95), torch.full_like(m, 63), m)
|
||||
m = torch.where((m > 95) & (m <= 191), torch.full_like(m, 127), m)
|
||||
m = torch.where(m > 191, torch.full_like(m, 255), m)
|
||||
|
||||
m = (255.0 - m) / 255.0
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class VOIDInpaintConditioning(io.ComfyNode):
|
||||
"""Build VOID inpainting conditioning for CogVideoX.
|
||||
|
||||
Encodes the processed quadmask and masked source video through the VAE,
|
||||
producing a 32-channel concat conditioning (16ch mask + 16ch masked video)
|
||||
that gets concatenated with the 16ch noise latent by the model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VOIDInpaintConditioning",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("video", tooltip="Source video frames [T, H, W, 3]"),
|
||||
io.Mask.Input("quadmask", tooltip="Preprocessed quadmask from VOIDQuadmaskPreprocess [T, H, W]"),
|
||||
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1,
|
||||
tooltip="Number of pixel frames to process. For CogVideoX-Fun-V1.5 "
|
||||
"(patch_size_t=2), latent_t must be even — lengths that "
|
||||
"produce odd latent_t are rounded down (e.g. 49 → 45)."),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, video, quadmask,
|
||||
width, height, length, batch_size) -> io.NodeOutput:
|
||||
|
||||
adjusted_length = _valid_void_length(length)
|
||||
if adjusted_length != length:
|
||||
logging.warning(
|
||||
"VOIDInpaintConditioning: rounding length %d down to %d so that "
|
||||
"latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2). "
|
||||
"Using odd latent_t causes the last frame to be corrupted by "
|
||||
"circular padding.", length, adjusted_length,
|
||||
)
|
||||
length = adjusted_length
|
||||
|
||||
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
|
||||
latent_h = height // 8
|
||||
latent_w = width // 8
|
||||
|
||||
vid = video[:length]
|
||||
vid = comfy.utils.common_upscale(
|
||||
vid.movedim(-1, 1), width, height, "bilinear", "center"
|
||||
).movedim(1, -1)
|
||||
|
||||
qm = quadmask[:length]
|
||||
if qm.ndim == 3:
|
||||
qm = qm.unsqueeze(-1)
|
||||
qm = comfy.utils.common_upscale(
|
||||
qm.movedim(-1, 1), width, height, "bilinear", "center"
|
||||
).movedim(1, -1)
|
||||
if qm.ndim == 4 and qm.shape[-1] == 1:
|
||||
qm = qm.squeeze(-1)
|
||||
|
||||
mask_condition = qm
|
||||
if mask_condition.ndim == 3:
|
||||
mask_condition_3ch = mask_condition.unsqueeze(-1).expand(-1, -1, -1, 3)
|
||||
else:
|
||||
mask_condition_3ch = mask_condition
|
||||
|
||||
inverted_mask_3ch = 1.0 - mask_condition_3ch
|
||||
masked_video = vid[:, :, :, :3] * (1.0 - mask_condition_3ch)
|
||||
|
||||
mask_latents = vae.encode(inverted_mask_3ch)
|
||||
masked_video_latents = vae.encode(masked_video)
|
||||
|
||||
def _match_temporal(lat, target_t):
|
||||
if lat.shape[2] > target_t:
|
||||
return lat[:, :, :target_t]
|
||||
elif lat.shape[2] < target_t:
|
||||
pad = target_t - lat.shape[2]
|
||||
return torch.cat([lat, lat[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2)
|
||||
return lat
|
||||
|
||||
mask_latents = _match_temporal(mask_latents, latent_t)
|
||||
masked_video_latents = _match_temporal(masked_video_latents, latent_t)
|
||||
|
||||
inpaint_latents = torch.cat([mask_latents, masked_video_latents], dim=1)
|
||||
|
||||
# No explicit scaling needed here: the model's CogVideoX.concat_cond()
|
||||
# applies process_latent_in (×latent_format.scale_factor) to each 16-ch
|
||||
# block of the stored conditioning. For 5b-class checkpoints (incl. the
|
||||
# VOID/CogVideoX-Fun-V1.5 inpainting model) that scale_factor is auto-
|
||||
# selected as 0.7 in supported_models.CogVideoX_T2V, which matches the
|
||||
# diffusers vae/config.json scaling_factor VOID was trained with.
|
||||
|
||||
positive = node_helpers.conditioning_set_values(
|
||||
positive, {"concat_latent_image": inpaint_latents}
|
||||
)
|
||||
negative = node_helpers.conditioning_set_values(
|
||||
negative, {"concat_latent_image": inpaint_latents}
|
||||
)
|
||||
|
||||
noise_latent = torch.zeros(
|
||||
[batch_size, 16, latent_t, latent_h, latent_w],
|
||||
device=comfy.model_management.intermediate_device()
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": noise_latent})
|
||||
|
||||
|
||||
class VOIDWarpedNoise(io.ComfyNode):
|
||||
"""Generate optical-flow warped noise for VOID Pass 2 refinement.
|
||||
|
||||
Takes the Pass 1 output video and produces temporally-correlated noise
|
||||
by warping Gaussian noise along optical flow vectors. This noise is used
|
||||
as the initial latent for Pass 2, resulting in better temporal consistency.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VOIDWarpedNoise",
|
||||
category="latent/video",
|
||||
inputs=[
|
||||
OpticalFlow.Input(
|
||||
"optical_flow",
|
||||
tooltip="Optical flow model from OpticalFlowLoader (RAFT-large).",
|
||||
),
|
||||
io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"),
|
||||
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1,
|
||||
tooltip="Number of pixel frames. Rounded down to make latent_t "
|
||||
"even (patch_size_t=2 requirement), e.g. 49 → 45."),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="warped_noise"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, optical_flow, video, width, height, length, batch_size) -> io.NodeOutput:
|
||||
|
||||
adjusted_length = _valid_void_length(length)
|
||||
if adjusted_length != length:
|
||||
logging.warning(
|
||||
"VOIDWarpedNoise: rounding length %d down to %d so that "
|
||||
"latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2).",
|
||||
length, adjusted_length,
|
||||
)
|
||||
length = adjusted_length
|
||||
|
||||
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
|
||||
latent_h = height // 8
|
||||
latent_w = width // 8
|
||||
|
||||
# RAFT + noise warp is real compute, not an "intermediate" buffer, so
|
||||
# we want the actual torch device (CUDA/MPS). The final latent is
|
||||
# moved back to intermediate_device() before returning to match the
|
||||
# rest of the ComfyUI pipeline.
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
comfy.model_management.load_model_gpu(optical_flow)
|
||||
raft = RaftOpticalFlow(optical_flow.model, device=device)
|
||||
|
||||
vid = video[:length].to(device)
|
||||
vid = comfy.utils.common_upscale(
|
||||
vid.movedim(-1, 1), width, height, "bilinear", "center"
|
||||
).movedim(1, -1)
|
||||
vid_uint8 = (vid.clamp(0, 1) * 255).to(torch.uint8)
|
||||
|
||||
FRAME = 2**-1
|
||||
FLOW = 2**3
|
||||
LATENT_SCALE = 8
|
||||
|
||||
warped = get_noise_from_video(
|
||||
vid_uint8,
|
||||
raft,
|
||||
noise_channels=16,
|
||||
resize_frames=FRAME,
|
||||
resize_flow=FLOW,
|
||||
downscale_factor=round(FRAME * FLOW) * LATENT_SCALE,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if warped.shape[0] != latent_t:
|
||||
indices = torch.linspace(0, warped.shape[0] - 1, latent_t,
|
||||
device=device).long()
|
||||
warped = warped[indices]
|
||||
|
||||
if warped.shape[1] != latent_h or warped.shape[2] != latent_w:
|
||||
# (T, H, W, C) → (T, C, H, W) → bilinear resize → back
|
||||
warped = warped.permute(0, 3, 1, 2)
|
||||
warped = torch.nn.functional.interpolate(
|
||||
warped, size=(latent_h, latent_w),
|
||||
mode="bilinear", align_corners=False,
|
||||
)
|
||||
warped = warped.permute(0, 2, 3, 1)
|
||||
|
||||
# (T, H, W, C) → (B, C, T, H, W)
|
||||
warped_tensor = warped.permute(3, 0, 1, 2).unsqueeze(0)
|
||||
if batch_size > 1:
|
||||
warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
warped_tensor = warped_tensor.to(comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": warped_tensor})
|
||||
|
||||
|
||||
class Noise_FromLatent:
|
||||
"""Wraps a pre-computed LATENT tensor as a NOISE source."""
|
||||
def __init__(self, latent_dict):
|
||||
self.seed = 0
|
||||
self._samples = latent_dict["samples"]
|
||||
|
||||
def generate_noise(self, input_latent):
|
||||
return self._samples.clone().cpu()
|
||||
|
||||
|
||||
class VOIDWarpedNoiseSource(io.ComfyNode):
|
||||
"""Convert a LATENT (e.g. from VOIDWarpedNoise) into a NOISE source
|
||||
for use with SamplerCustomAdvanced."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VOIDWarpedNoiseSource",
|
||||
category="sampling/custom_sampling/noise",
|
||||
inputs=[
|
||||
io.Latent.Input("warped_noise",
|
||||
tooltip="Warped noise latent from VOIDWarpedNoise"),
|
||||
],
|
||||
outputs=[io.Noise.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, warped_noise) -> io.NodeOutput:
|
||||
return io.NodeOutput(Noise_FromLatent(warped_noise))
|
||||
|
||||
|
||||
class VOID_DDIM(comfy.samplers.Sampler):
|
||||
"""DDIM sampler for VOID inpainting models.
|
||||
|
||||
VOID was trained with the diffusers CogVideoXDDIMScheduler which operates in
|
||||
alpha-space (input std ≈ 1). The standard KSampler applies noise_scaling that
|
||||
multiplies by sqrt(1+sigma^2) ≈ 4500x, which is incompatible with VOID's
|
||||
training. This sampler skips noise_scaling and implements the DDIM update rule
|
||||
directly using sigma-to-alpha conversion.
|
||||
"""
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
x = noise.to(torch.float32)
|
||||
model_options = extra_args.get("model_options", {})
|
||||
seed = extra_args.get("seed", None)
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable_pbar):
|
||||
sigma = sigmas[i]
|
||||
sigma_next = sigmas[i + 1]
|
||||
|
||||
denoised = model_wrap(x, sigma * s_in, model_options=model_options, seed=seed)
|
||||
|
||||
if callback is not None:
|
||||
callback(i, denoised, x, len(sigmas) - 1)
|
||||
|
||||
if sigma_next == 0:
|
||||
x = denoised
|
||||
else:
|
||||
alpha_t = 1.0 / (1.0 + sigma ** 2)
|
||||
alpha_prev = 1.0 / (1.0 + sigma_next ** 2)
|
||||
|
||||
pred_eps = (x - (alpha_t ** 0.5) * denoised) / (1.0 - alpha_t) ** 0.5
|
||||
x = (alpha_prev ** 0.5) * denoised + (1.0 - alpha_prev) ** 0.5 * pred_eps
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VOIDSampler(io.ComfyNode):
|
||||
"""VOID DDIM sampler for use with SamplerCustom / SamplerCustomAdvanced.
|
||||
|
||||
Required for VOID inpainting models. Implements the same DDIM loop that VOID
|
||||
was trained with (diffusers CogVideoXDDIMScheduler), without the noise_scaling
|
||||
that the standard KSampler applies. Use with RandomNoise or VOIDWarpedNoiseSource.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VOIDSampler",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[],
|
||||
outputs=[io.Sampler.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls) -> io.NodeOutput:
|
||||
return io.NodeOutput(VOID_DDIM())
|
||||
|
||||
get_sampler = execute
|
||||
|
||||
|
||||
class VOIDExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
OpticalFlowLoader,
|
||||
VOIDQuadmaskPreprocess,
|
||||
VOIDInpaintConditioning,
|
||||
VOIDWarpedNoise,
|
||||
VOIDWarpedNoiseSource,
|
||||
VOIDSampler,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> VOIDExtension:
|
||||
return VOIDExtension()
|
||||
494
comfy_extras/void_noise_warp.py
Normal file
494
comfy_extras/void_noise_warp.py
Normal file
@ -0,0 +1,494 @@
|
||||
"""
|
||||
Optical-flow-warped noise for VOID Pass 2 refinement.
|
||||
|
||||
Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert):
|
||||
https://github.com/RyannDaGreat/CommonSource
|
||||
- noise_warp.py (NoiseWarper / warp_xyωc / regaussianize / get_noise_from_video)
|
||||
- raft.py (RaftOpticalFlow)
|
||||
|
||||
Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually
|
||||
uses (torch THWC uint8 input, no background removal, no visualization, no disk
|
||||
I/O, default warp/noise params) have been inlined. External ``rp`` utilities
|
||||
have been replaced with equivalents from torch.nn.functional / einops. The
|
||||
RAFT optical-flow model itself is loaded offline via ``OpticalFlowLoader`` in
|
||||
``nodes_void.py`` and passed into ``get_noise_from_video`` by the caller; this
|
||||
module never downloads weights at runtime.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-level torch image helpers (drop-in replacements for rp.torch_* primitives)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _torch_resize_chw(image, size, interp, copy=True):
|
||||
"""Resize a CHW tensor.
|
||||
|
||||
``size`` is either a scalar factor or a (h, w) tuple. ``interp`` is one
|
||||
of ``"bilinear"``, ``"nearest"``, ``"area"``. When ``copy`` is False and
|
||||
the requested size matches the input, returns the input tensor as is
|
||||
(faster but callers must not mutate the result).
|
||||
"""
|
||||
if image.ndim != 3:
|
||||
raise ValueError(
|
||||
f"_torch_resize_chw expects a 3D CHW tensor, got shape {tuple(image.shape)}"
|
||||
)
|
||||
_, in_h, in_w = image.shape
|
||||
if isinstance(size, (int, float)) and not isinstance(size, bool):
|
||||
new_h = max(1, int(in_h * size))
|
||||
new_w = max(1, int(in_w * size))
|
||||
else:
|
||||
new_h, new_w = size
|
||||
|
||||
if (new_h, new_w) == (in_h, in_w):
|
||||
return image.clone() if copy else image
|
||||
|
||||
kwargs = {}
|
||||
if interp in ("bilinear", "bicubic"):
|
||||
kwargs["align_corners"] = False
|
||||
out = F.interpolate(image[None], size=(new_h, new_w), mode=interp, **kwargs)[0]
|
||||
return out
|
||||
|
||||
|
||||
def _torch_remap_relative(image, dx, dy, interp="bilinear"):
|
||||
"""Relative remap of a CHW image via ``F.grid_sample``.
|
||||
|
||||
Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
|
||||
for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
|
||||
"""
|
||||
if image.ndim != 3:
|
||||
raise ValueError(
|
||||
f"_torch_remap_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
|
||||
)
|
||||
if dx.shape != dy.shape:
|
||||
raise ValueError(
|
||||
f"_torch_remap_relative: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
|
||||
)
|
||||
_, h, w = image.shape
|
||||
|
||||
x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
|
||||
y_abs = dy + torch.arange(h, device=dy.device, dtype=dy.dtype)[:, None]
|
||||
|
||||
x_norm = (x_abs / (w - 1)) * 2 - 1
|
||||
y_norm = (y_abs / (h - 1)) * 2 - 1
|
||||
|
||||
grid = torch.stack([x_norm, y_norm], dim=-1)[None].to(image.dtype)
|
||||
out = F.grid_sample(
|
||||
image[None], grid, mode=interp, align_corners=True, padding_mode="zeros"
|
||||
)[0]
|
||||
return out
|
||||
|
||||
|
||||
def _torch_scatter_add_relative(image, dx, dy):
|
||||
"""Scatter-add a CHW image using relative floor-rounded (dx, dy) offsets.
|
||||
|
||||
Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
|
||||
interp='floor')``. Out-of-bounds targets are dropped.
|
||||
"""
|
||||
if image.ndim != 3:
|
||||
raise ValueError(
|
||||
f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
|
||||
)
|
||||
in_c, in_h, in_w = image.shape
|
||||
if dx.shape != (in_h, in_w) or dy.shape != (in_h, in_w):
|
||||
raise ValueError(
|
||||
f"_torch_scatter_add_relative: dx/dy must be ({in_h}, {in_w}), "
|
||||
f"got dx={tuple(dx.shape)} dy={tuple(dy.shape)}"
|
||||
)
|
||||
|
||||
x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long)
|
||||
y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
|
||||
|
||||
valid = ((y >= 0) & (y < in_h) & (x >= 0) & (x < in_w)).reshape(-1)
|
||||
indices = (y * in_w + x).reshape(-1)[valid]
|
||||
|
||||
flat_image = rearrange(image, "c h w -> (h w) c")[valid]
|
||||
out = torch.zeros((in_h * in_w, in_c), dtype=image.dtype, device=image.device)
|
||||
out.index_add_(0, indices, flat_image)
|
||||
return rearrange(out, "(h w) c -> c h w", h=in_h, w=in_w)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Noise warping primitives (ported from noise_warp.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def unique_pixels(image):
|
||||
"""Find unique pixel values in a CHW tensor.
|
||||
|
||||
Returns ``(unique_colors [U, C], counts [U], index_matrix [H, W])`` where
|
||||
``index_matrix[i, j]`` is the index of the unique color at that pixel.
|
||||
"""
|
||||
_, h, w = image.shape
|
||||
flat = rearrange(image, "c h w -> (h w) c")
|
||||
unique_colors, inverse_indices, counts = torch.unique(
|
||||
flat, dim=0, return_inverse=True, return_counts=True, sorted=False,
|
||||
)
|
||||
index_matrix = rearrange(inverse_indices, "(h w) -> h w", h=h, w=w)
|
||||
return unique_colors, counts, index_matrix
|
||||
|
||||
|
||||
def sum_indexed_values(image, index_matrix):
|
||||
"""For each unique index, sum the CHW image values at its pixels."""
|
||||
_, h, w = image.shape
|
||||
u = int(index_matrix.max().item()) + 1
|
||||
flat = rearrange(image, "c h w -> (h w) c")
|
||||
out = torch.zeros((u, flat.shape[1]), dtype=flat.dtype, device=flat.device)
|
||||
out.index_add_(0, index_matrix.view(-1), flat)
|
||||
return out
|
||||
|
||||
|
||||
def indexed_to_image(index_matrix, unique_colors):
|
||||
"""Build a CHW image from an index matrix and a (U, C) color table."""
|
||||
h, w = index_matrix.shape
|
||||
flat = unique_colors[index_matrix.view(-1)]
|
||||
return rearrange(flat, "(h w) c -> c h w", h=h, w=w)
|
||||
|
||||
|
||||
def regaussianize(noise):
|
||||
"""Variance-preserving re-sampling of a CHW noise tensor.
|
||||
|
||||
Wherever the noise contains groups of identical pixel values (e.g. after
|
||||
a nearest-neighbor warp that duplicated source pixels), adds zero-mean
|
||||
foreign noise within each group and scales by ``1/sqrt(count)`` so the
|
||||
output is unit-variance gaussian again.
|
||||
"""
|
||||
_, hs, ws = noise.shape
|
||||
_, counts, index_matrix = unique_pixels(noise[:1])
|
||||
|
||||
foreign_noise = torch.randn_like(noise)
|
||||
summed = sum_indexed_values(foreign_noise, index_matrix)
|
||||
meaned = indexed_to_image(index_matrix, summed / rearrange(counts, "u -> u 1"))
|
||||
zeroed_foreign = foreign_noise - meaned
|
||||
|
||||
counts_image = indexed_to_image(index_matrix, rearrange(counts, "u -> u 1"))
|
||||
|
||||
output = noise / counts_image ** 0.5 + zeroed_foreign
|
||||
return output, counts_image
|
||||
|
||||
|
||||
def xy_meshgrid_like_image(image):
|
||||
"""Return a (2, H, W) tensor of (x, y) pixel coordinates matching ``image``."""
|
||||
_, h, w = image.shape
|
||||
y, x = torch.meshgrid(
|
||||
torch.arange(h, device=image.device, dtype=image.dtype),
|
||||
torch.arange(w, device=image.device, dtype=image.dtype),
|
||||
indexing="ij",
|
||||
)
|
||||
return torch.stack([x, y])
|
||||
|
||||
|
||||
def noise_to_state(noise):
|
||||
"""Pack a (C, H, W) noise tensor into a state tensor (3+C, H, W) = [dx, dy, ω, noise]."""
|
||||
zeros = torch.zeros_like(noise[:1])
|
||||
ones = torch.ones_like(noise[:1])
|
||||
return torch.cat([zeros, zeros, ones, noise])
|
||||
|
||||
|
||||
def state_to_noise(state):
|
||||
"""Unpack the noise channels from a state tensor."""
|
||||
return state[3:]
|
||||
|
||||
|
||||
def warp_state(state, flow):
|
||||
"""Warp a noise-warper state tensor along the given optical flow.
|
||||
|
||||
``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
|
||||
``flow`` has shape ``(2, h, w)`` (= dx, dy).
|
||||
"""
|
||||
if flow.device != state.device:
|
||||
raise ValueError(
|
||||
f"warp_state: flow and state must be on the same device, "
|
||||
f"got flow={flow.device} state={state.device}"
|
||||
)
|
||||
if state.ndim != 3:
|
||||
raise ValueError(
|
||||
f"warp_state: state must be 3D (3+C, H, W), got shape {tuple(state.shape)}"
|
||||
)
|
||||
xyoc, h, w = state.shape
|
||||
if flow.shape != (2, h, w):
|
||||
raise ValueError(
|
||||
f"warp_state: flow must have shape (2, {h}, {w}), got {tuple(flow.shape)}"
|
||||
)
|
||||
device = state.device
|
||||
|
||||
x_ch, y_ch = 0, 1
|
||||
xy = 2 # state[:xy] = [dx, dy]
|
||||
xyw = 3 # state[:xyw] = [dx, dy, ω]
|
||||
w_ch = 2 # state[w_ch] = ω
|
||||
c = xyoc - xyw
|
||||
oc = xyoc - xy
|
||||
if c <= 0:
|
||||
raise ValueError(
|
||||
f"warp_state: state has no noise channels (expected 3+C with C>0, got {xyoc} channels)"
|
||||
)
|
||||
if not (state[w_ch] > 0).all():
|
||||
raise ValueError("warp_state: all weights in state[2] must be > 0")
|
||||
|
||||
grid = xy_meshgrid_like_image(state)
|
||||
|
||||
init = torch.empty_like(state)
|
||||
init[:xy] = 0
|
||||
init[w_ch] = 1
|
||||
init[-c:] = 0
|
||||
|
||||
# --- Expansion branch: nearest-neighbor remap with negated flow ---
|
||||
pre_expand = torch.empty_like(state)
|
||||
pre_expand[:xy] = _torch_remap_relative(state[:xy], -flow[0], -flow[1], "nearest")
|
||||
pre_expand[-oc:] = _torch_remap_relative(state[-oc:], -flow[0], -flow[1], "nearest")
|
||||
pre_expand[w_ch][pre_expand[w_ch] == 0] = 1
|
||||
|
||||
# --- Shrink branch: scatter-add state into new positions ---
|
||||
pre_shrink = state.clone()
|
||||
pre_shrink[:xy] += flow
|
||||
|
||||
pos = (grid + pre_shrink[:xy]).round()
|
||||
in_bounds = (pos[x_ch] >= 0) & (pos[x_ch] < w) & (pos[y_ch] >= 0) & (pos[y_ch] < h)
|
||||
pre_shrink = torch.where(~in_bounds[None], init, pre_shrink)
|
||||
|
||||
scat_xy = pre_shrink[:xy].round()
|
||||
pre_shrink[:xy] -= scat_xy
|
||||
pre_shrink[:xy] = 0 # xy_mode='none' in upstream
|
||||
|
||||
def scat(tensor):
|
||||
return _torch_scatter_add_relative(tensor, scat_xy[0], scat_xy[1])
|
||||
|
||||
# rp.torch_scatter_add_image on a bool tensor errors on modern torch;
|
||||
# scatter-sum a float ones tensor and threshold to get the mask instead.
|
||||
shrink_mask = scat(torch.ones(1, h, w, dtype=state.dtype, device=device)) > 0
|
||||
|
||||
# Drop expansion samples at positions that will be filled by shrink.
|
||||
pre_expand = torch.where(shrink_mask, init, pre_expand)
|
||||
|
||||
# Regaussianize both branches together so duplicated-source groups are
|
||||
# counted globally, then split back apart.
|
||||
concat = torch.cat([pre_shrink, pre_expand], dim=2) # along width
|
||||
concat[-c:], counts_image = regaussianize(concat[-c:])
|
||||
concat[w_ch] = concat[w_ch] / counts_image[0]
|
||||
concat[w_ch] = concat[w_ch].nan_to_num()
|
||||
pre_shrink, expand = torch.chunk(concat, chunks=2, dim=2)
|
||||
|
||||
shrink = torch.empty_like(pre_shrink)
|
||||
shrink[w_ch] = scat(pre_shrink[w_ch][None])[0]
|
||||
shrink[:xy] = scat(pre_shrink[:xy] * pre_shrink[w_ch][None]) / shrink[w_ch][None]
|
||||
shrink[-c:] = scat(pre_shrink[-c:] * pre_shrink[w_ch][None]) / scat(
|
||||
pre_shrink[w_ch][None] ** 2
|
||||
).sqrt()
|
||||
|
||||
output = torch.where(shrink_mask, shrink, expand)
|
||||
output[w_ch] = output[w_ch] / output[w_ch].mean()
|
||||
output[w_ch] += 1e-5
|
||||
output[w_ch] **= 0.9999
|
||||
return output
|
||||
|
||||
|
||||
class NoiseWarper:
|
||||
"""Maintain a warpable noise state and emit gaussian noise per frame.
|
||||
|
||||
Simplified from RyannDaGreat/CommonSource/noise_warp.py::NoiseWarper:
|
||||
``scale_factor``, ``post_noise_alpha``, ``progressive_noise_alpha``, and
|
||||
``warp_kwargs`` are all dropped since VOIDWarpedNoise always uses defaults.
|
||||
"""
|
||||
|
||||
def __init__(self, c, h, w, device, dtype=torch.float32):
|
||||
if c <= 0 or h <= 0 or w <= 0:
|
||||
raise ValueError(
|
||||
f"NoiseWarper: c/h/w must all be positive, got c={c} h={h} w={w}"
|
||||
)
|
||||
self.c = c
|
||||
self.h = h
|
||||
self.w = w
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
noise = torch.randn(c, h, w, dtype=dtype, device=device)
|
||||
self._state = noise_to_state(noise)
|
||||
|
||||
@property
|
||||
def noise(self):
|
||||
# With scale_factor=1 the "downsample to respect weights" step is a
|
||||
# size-preserving no-op; the weight-variance correction math still
|
||||
# runs to stay faithful to upstream.
|
||||
n = state_to_noise(self._state)
|
||||
weights = self._state[2:3]
|
||||
return n * weights / (weights ** 2).sqrt()
|
||||
|
||||
def __call__(self, dx, dy):
|
||||
if dx.shape != dy.shape:
|
||||
raise ValueError(
|
||||
f"NoiseWarper: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
|
||||
)
|
||||
flow = torch.stack([dx, dy]).to(self.device, self.dtype)
|
||||
_, oflowh, ofloww = flow.shape
|
||||
|
||||
flow = _torch_resize_chw(flow, (self.h, self.w), "bilinear", copy=True)
|
||||
flowh, floww = flow.shape[-2:]
|
||||
|
||||
# Upstream scales flow[0] by flowh/oflowh and flow[1] by floww/ofloww
|
||||
# (channel-order appears swapped but harmless when H and W are scaled
|
||||
# by the same factor, which is always the case for our callers).
|
||||
flow[0] *= flowh / oflowh
|
||||
flow[1] *= floww / ofloww
|
||||
|
||||
self._state = warp_state(self._state, flow)
|
||||
return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RAFT optical flow wrapper (ported from raft.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RaftOpticalFlow:
|
||||
"""RAFT-large wrapper around a pre-loaded torchvision model.
|
||||
|
||||
``model`` must be the ``torchvision.models.optical_flow.raft_large`` module
|
||||
with its weights already populated; this class is load-agnostic so the
|
||||
caller owns downloading/offload concerns (see ``OpticalFlowLoader`` in
|
||||
``nodes_void.py``). ``__call__`` returns a ``(2, H, W)`` flow.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device=None):
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
self.device = device
|
||||
self.model = model
|
||||
|
||||
def _preprocess(self, image_chw):
|
||||
image = image_chw.to(self.device, torch.float32)
|
||||
_, h, w = image.shape
|
||||
new_h = (h // 8) * 8
|
||||
new_w = (w // 8) * 8
|
||||
image = _torch_resize_chw(image, (new_h, new_w), "bilinear", copy=False)
|
||||
image = image * 2 - 1
|
||||
return image[None]
|
||||
|
||||
def __call__(self, from_image, to_image):
|
||||
"""``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
|
||||
if from_image.shape != to_image.shape:
|
||||
raise ValueError(
|
||||
f"RaftOpticalFlow: from_image and to_image must match, "
|
||||
f"got {tuple(from_image.shape)} vs {tuple(to_image.shape)}"
|
||||
)
|
||||
_, h, w = from_image.shape
|
||||
with torch.no_grad():
|
||||
img1 = self._preprocess(from_image)
|
||||
img2 = self._preprocess(to_image)
|
||||
list_of_flows = self.model(img1, img2)
|
||||
flow = list_of_flows[-1][0] # (2, new_h, new_w)
|
||||
if flow.shape[-2:] != (h, w):
|
||||
flow = _torch_resize_chw(flow, (h, w), "bilinear", copy=False)
|
||||
return flow
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Narrow entry point used by VOIDWarpedNoise
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_noise_from_video(
|
||||
video_frames: torch.Tensor,
|
||||
raft: RaftOpticalFlow,
|
||||
*,
|
||||
noise_channels: int = 16,
|
||||
resize_frames: float = 0.5,
|
||||
resize_flow: int = 8,
|
||||
downscale_factor: int = 32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Produce optical-flow-warped gaussian noise from a video.
|
||||
|
||||
Args:
|
||||
video_frames: ``(T, H, W, 3)`` uint8 torch tensor.
|
||||
raft: Pre-loaded RAFT optical-flow wrapper (see ``RaftOpticalFlow``).
|
||||
noise_channels: Channels in the output noise.
|
||||
resize_frames: Pre-RAFT frame scale factor.
|
||||
resize_flow: Post-flow up-scale factor applied to the optical flow;
|
||||
the internal noise state is allocated at
|
||||
``(resize_flow * resize_frames * H, resize_flow * resize_frames * W)``.
|
||||
downscale_factor: Area-pool factor applied to the noise before return;
|
||||
should evenly divide the internal noise resolution.
|
||||
device: Target device. Defaults to ``comfy.model_management.get_torch_device()``.
|
||||
|
||||
Returns:
|
||||
``(T, H', W', noise_channels)`` float32 noise tensor on ``device``.
|
||||
"""
|
||||
if not isinstance(resize_flow, int) or resize_flow < 1:
|
||||
raise ValueError(
|
||||
f"get_noise_from_video: resize_flow must be a positive int, got {resize_flow!r}"
|
||||
)
|
||||
if video_frames.ndim != 4 or video_frames.shape[-1] != 3:
|
||||
raise ValueError(
|
||||
"get_noise_from_video: video_frames must have shape (T, H, W, 3), "
|
||||
f"got {tuple(video_frames.shape)}"
|
||||
)
|
||||
if video_frames.dtype != torch.uint8:
|
||||
raise TypeError(
|
||||
"get_noise_from_video: video_frames must be uint8 in [0, 255], "
|
||||
f"got dtype {video_frames.dtype}"
|
||||
)
|
||||
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||
|
||||
if device.type == "cpu":
|
||||
logging.warning(
|
||||
"VOIDWarpedNoise: running get_noise_from_video on CPU; this will be "
|
||||
"slow (minutes for ~45 frames). Use CUDA for interactive use."
|
||||
)
|
||||
|
||||
T = video_frames.shape[0]
|
||||
frames = video_frames.to(device).permute(0, 3, 1, 2).to(torch.float32) / 255.0
|
||||
if resize_frames != 1.0:
|
||||
new_h = max(1, int(frames.shape[2] * resize_frames))
|
||||
new_w = max(1, int(frames.shape[3] * resize_frames))
|
||||
frames = F.interpolate(frames, size=(new_h, new_w), mode="area")
|
||||
|
||||
_, _, H, W = frames.shape
|
||||
internal_h = resize_flow * H
|
||||
internal_w = resize_flow * W
|
||||
if internal_h % downscale_factor or internal_w % downscale_factor:
|
||||
logging.warning(
|
||||
"VOIDWarpedNoise: internal noise size %dx%d is not divisible by "
|
||||
"downscale_factor %d; output noise may have artifacts.",
|
||||
internal_h, internal_w, downscale_factor,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
warper = NoiseWarper(
|
||||
c=noise_channels, h=internal_h, w=internal_w, device=device,
|
||||
)
|
||||
down_h = warper.h // downscale_factor
|
||||
down_w = warper.w // downscale_factor
|
||||
output = torch.empty(
|
||||
(T, down_h, down_w, noise_channels), dtype=torch.float32, device=device,
|
||||
)
|
||||
|
||||
def downscale(noise_chw):
|
||||
# Area-pool to 1/downscale_factor then multiply by downscale_factor
|
||||
# to adjust std (sqrt of pool area == downscale_factor for a
|
||||
# square pool).
|
||||
down = _torch_resize_chw(noise_chw, 1.0 / downscale_factor, "area", copy=False)
|
||||
return down * downscale_factor
|
||||
|
||||
output[0] = downscale(warper.noise).permute(1, 2, 0)
|
||||
|
||||
prev = frames[0]
|
||||
for i in range(1, T):
|
||||
curr = frames[i]
|
||||
flow = raft(prev, curr).to(device)
|
||||
warper(flow[0], flow[1])
|
||||
output[i] = downscale(warper.noise).permute(1, 2, 0)
|
||||
prev = curr
|
||||
|
||||
return output
|
||||
11
execution.py
11
execution.py
@ -15,6 +15,7 @@ import torch
|
||||
from comfy.cli_args import args
|
||||
import comfy.memory_management
|
||||
import comfy.model_management
|
||||
import comfy.model_prefetch
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
from latent_preview import set_preview_method
|
||||
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if args.verbose == "DEBUG":
|
||||
comfy_aimdo.control.analyze()
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
comfy.model_prefetch.cleanup_prefetch_queues()
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
|
||||
if has_pending_tasks:
|
||||
@ -1017,7 +1019,12 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
combo_options = extra_info.get("options", [])
|
||||
else:
|
||||
combo_options = input_type
|
||||
if val not in combo_options:
|
||||
is_multiselect = extra_info.get("multiselect", False)
|
||||
if is_multiselect and isinstance(val, list):
|
||||
invalid_vals = [v for v in val if v not in combo_options]
|
||||
else:
|
||||
invalid_vals = [val] if val not in combo_options else []
|
||||
if invalid_vals:
|
||||
input_config = info
|
||||
list_info = ""
|
||||
|
||||
@ -1032,7 +1039,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
|
||||
error = {
|
||||
"type": "value_not_in_list",
|
||||
"message": "Value not in list",
|
||||
"details": f"{x}: '{val}' not in {list_info}",
|
||||
"details": f"{x}: {', '.join(repr(v) for v in invalid_vals)} not in {list_info}",
|
||||
"extra_info": {
|
||||
"input_name": x,
|
||||
"input_config": input_config,
|
||||
|
||||
@ -28,7 +28,7 @@
|
||||
#config for a1111 ui
|
||||
#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed
|
||||
|
||||
#a111:
|
||||
#a1111:
|
||||
# base_path: path/to/stable-diffusion-webui/
|
||||
# checkpoints: models/Stable-diffusion
|
||||
# configs: models/Stable-diffusion
|
||||
|
||||
@ -54,6 +54,8 @@ folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_enc
|
||||
|
||||
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
input_directory = os.path.join(base_path, "input")
|
||||
@ -432,7 +434,9 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
|
||||
prefix_len = len(os.path.basename(filename_prefix))
|
||||
prefix = filename[:prefix_len + 1]
|
||||
try:
|
||||
digits = int(filename[prefix_len + 1:].split('_')[0])
|
||||
remainder = filename[prefix_len + 1:]
|
||||
base_remainder = remainder.split('.')[0]
|
||||
digits = int(base_remainder.split('_')[0])
|
||||
except:
|
||||
digits = 0
|
||||
return digits, prefix
|
||||
|
||||
10
main.py
10
main.py
@ -1,13 +1,21 @@
|
||||
import comfy.options
|
||||
comfy.options.enable_args_parsing()
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
if args.list_feature_flags:
|
||||
import json
|
||||
from comfy_api.feature_flags import CLI_FEATURE_FLAG_REGISTRY
|
||||
print(json.dumps(CLI_FEATURE_FLAG_REGISTRY, indent=2)) # noqa: T201
|
||||
raise SystemExit(0)
|
||||
|
||||
import os
|
||||
import importlib.util
|
||||
import shutil
|
||||
import importlib.metadata
|
||||
import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from comfy.cli_args import enables_dynamic_vram
|
||||
from app.logger import setup_logger
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
|
||||
0
models/optical_flow/put_optical_flow_models_here
Normal file
0
models/optical_flow/put_optical_flow_models_here
Normal file
@ -86,6 +86,6 @@ def image_alpha_fix(destination, source):
|
||||
if destination.shape[-1] < source.shape[-1]:
|
||||
source = source[...,:destination.shape[-1]]
|
||||
elif destination.shape[-1] > source.shape[-1]:
|
||||
destination = torch.nn.functional.pad(destination, (0, 1))
|
||||
destination[..., -1] = 1.0
|
||||
source = torch.nn.functional.pad(source, (0, 1))
|
||||
source[..., -1] = 1.0
|
||||
return destination, source
|
||||
|
||||
183
nodes.py
183
nodes.py
@ -728,50 +728,26 @@ class LoraLoaderModelOnly(LoraLoader):
|
||||
|
||||
class VAELoader:
|
||||
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
|
||||
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"]
|
||||
image_taes = ["taesd", "taesdxl", "taesd3", "taef1", "taef2"]
|
||||
|
||||
@staticmethod
|
||||
def vae_list(s):
|
||||
vaes = folder_paths.get_filename_list("vae")
|
||||
approx_vaes = folder_paths.get_filename_list("vae_approx")
|
||||
sdxl_taesd_enc = False
|
||||
sdxl_taesd_dec = False
|
||||
sd1_taesd_enc = False
|
||||
sd1_taesd_dec = False
|
||||
sd3_taesd_enc = False
|
||||
sd3_taesd_dec = False
|
||||
f1_taesd_enc = False
|
||||
f1_taesd_dec = False
|
||||
|
||||
have_img_encoder, have_img_decoder = set(), set()
|
||||
for v in approx_vaes:
|
||||
if v.startswith("taesd_decoder."):
|
||||
sd1_taesd_dec = True
|
||||
elif v.startswith("taesd_encoder."):
|
||||
sd1_taesd_enc = True
|
||||
elif v.startswith("taesdxl_decoder."):
|
||||
sdxl_taesd_dec = True
|
||||
elif v.startswith("taesdxl_encoder."):
|
||||
sdxl_taesd_enc = True
|
||||
elif v.startswith("taesd3_decoder."):
|
||||
sd3_taesd_dec = True
|
||||
elif v.startswith("taesd3_encoder."):
|
||||
sd3_taesd_enc = True
|
||||
elif v.startswith("taef1_encoder."):
|
||||
f1_taesd_dec = True
|
||||
elif v.startswith("taef1_decoder."):
|
||||
f1_taesd_enc = True
|
||||
else:
|
||||
parts = v.split("_", 1)
|
||||
if len(parts) != 2 or parts[0] not in s.image_taes:
|
||||
for tae in s.video_taes:
|
||||
if v.startswith(tae):
|
||||
vaes.append(v)
|
||||
|
||||
if sd1_taesd_dec and sd1_taesd_enc:
|
||||
vaes.append("taesd")
|
||||
if sdxl_taesd_dec and sdxl_taesd_enc:
|
||||
vaes.append("taesdxl")
|
||||
if sd3_taesd_dec and sd3_taesd_enc:
|
||||
vaes.append("taesd3")
|
||||
if f1_taesd_dec and f1_taesd_enc:
|
||||
vaes.append("taef1")
|
||||
break
|
||||
continue
|
||||
if parts[1].startswith("encoder."):
|
||||
have_img_encoder.add(parts[0])
|
||||
elif parts[1].startswith("decoder."):
|
||||
have_img_decoder.add(parts[0])
|
||||
vaes += [k for k in have_img_decoder if k in have_img_encoder]
|
||||
vaes.append("pixel_space")
|
||||
return vaes
|
||||
|
||||
@ -827,6 +803,11 @@ class VAELoader:
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
|
||||
if vae_name == "taef2":
|
||||
if metadata is None:
|
||||
metadata = {"tae_latent_channels": 128}
|
||||
else:
|
||||
metadata["tae_latent_channels"] = 128
|
||||
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
|
||||
vae.throw_exception_if_invalid()
|
||||
return (vae,)
|
||||
@ -977,7 +958,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -987,7 +968,7 @@ class CLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
@ -1713,26 +1694,27 @@ class LoadImage:
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "load_image"
|
||||
|
||||
def load_image(self, image):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
device = comfy.model_management.intermediate_device()
|
||||
|
||||
components = InputImpl.VideoFromFile(image_path).get_components()
|
||||
if components.images.shape[0] > 0:
|
||||
return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu"))
|
||||
return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device))
|
||||
|
||||
# This code is left here to handle animated webp which pyav does not support loading
|
||||
img = node_helpers.pillow(Image.open, image_path)
|
||||
|
||||
output_images = []
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
image = i.convert("RGB")
|
||||
|
||||
if len(output_images) == 0:
|
||||
@ -1747,25 +1729,15 @@ class LoadImage:
|
||||
if 'A' in i.getbands():
|
||||
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
elif i.mode == 'P' and 'transparency' in i.info:
|
||||
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
|
||||
output_images.append(image.to(dtype=dtype))
|
||||
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
|
||||
|
||||
if img.format == "MPO":
|
||||
break # ignore all frames except the first one for MPO format
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
|
||||
if len(output_images) > 1:
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
else:
|
||||
output_image = output_images[0]
|
||||
output_mask = output_masks[0]
|
||||
|
||||
return (output_image, output_mask)
|
||||
return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype))
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
@ -1782,57 +1754,49 @@ class LoadImage:
|
||||
|
||||
return True
|
||||
|
||||
class LoadImageMask:
|
||||
|
||||
class LoadImageMask(LoadImage):
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
|
||||
|
||||
_color_channels = ["alpha", "red", "green", "blue"]
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||
return {"required":
|
||||
{"image": (sorted(files), {"image_upload": True}),
|
||||
"channel": (s._color_channels, ), }
|
||||
}
|
||||
types = super().INPUT_TYPES()
|
||||
return {
|
||||
"required": {
|
||||
**types["required"],
|
||||
"channel": (s._color_channels, )
|
||||
}
|
||||
}
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "load_image"
|
||||
def load_image(self, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
i = node_helpers.pillow(Image.open, image_path)
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
if i.getbands() != ("R", "G", "B", "A"):
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
i = i.convert("RGBA")
|
||||
mask = None
|
||||
FUNCTION = "load_image_mask"
|
||||
|
||||
def load_image_mask(self, image, channel):
|
||||
image_tensor, mask_tensor = super().load_image(image)
|
||||
c = channel[0].upper()
|
||||
if c in i.getbands():
|
||||
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
|
||||
mask = torch.from_numpy(mask)
|
||||
if c == 'A':
|
||||
mask = 1. - mask
|
||||
|
||||
if c == 'A':
|
||||
return (mask_tensor,)
|
||||
|
||||
channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0)
|
||||
|
||||
if channel_idx < image_tensor.shape[-1]:
|
||||
return (image_tensor[..., channel_idx].clone(),)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
return (mask.unsqueeze(0),)
|
||||
empty_mask = torch.zeros(
|
||||
image_tensor.shape[:-1],
|
||||
dtype=image_tensor.dtype,
|
||||
device=image_tensor.device
|
||||
)
|
||||
return (empty_mask,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image, channel):
|
||||
image_path = folder_paths.get_annotated_filepath(image)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, image):
|
||||
if not folder_paths.exists_annotated_filepath(image):
|
||||
return "Invalid image file: {}".format(image)
|
||||
|
||||
return True
|
||||
return super().IS_CHANGED(image)
|
||||
|
||||
|
||||
class LoadImageOutput(LoadImage):
|
||||
@ -1923,7 +1887,7 @@ class ImageInvert:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "invert"
|
||||
|
||||
CATEGORY = "image"
|
||||
CATEGORY = "image/color"
|
||||
|
||||
def invert(self, image):
|
||||
s = 1.0 - image
|
||||
@ -1939,7 +1903,7 @@ class ImageBatch:
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "batch"
|
||||
|
||||
CATEGORY = "image"
|
||||
CATEGORY = "image/batch"
|
||||
DEPRECATED = True
|
||||
|
||||
def batch(self, image1, image2):
|
||||
@ -1996,7 +1960,7 @@ class ImagePadForOutpaint:
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
FUNCTION = "expand_image"
|
||||
|
||||
CATEGORY = "image"
|
||||
CATEGORY = "image/transform"
|
||||
|
||||
def expand_image(self, image, left, top, right, bottom, feathering):
|
||||
d1, d2, d3, d4 = image.size()
|
||||
@ -2139,7 +2103,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
||||
"ConditioningSetMask": "Conditioning (Set Mask)",
|
||||
"ControlNetApply": "Apply ControlNet (OLD)",
|
||||
"ControlNetApply": "Apply ControlNet (DEPRECATED)",
|
||||
"ControlNetApplyAdvanced": "Apply ControlNet",
|
||||
# Latent
|
||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||
@ -2157,6 +2121,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LatentFromBatch" : "Latent From Batch",
|
||||
"RepeatLatentBatch": "Repeat Latent Batch",
|
||||
# Image
|
||||
"EmptyImage": "Empty Image",
|
||||
"SaveImage": "Save Image",
|
||||
"PreviewImage": "Preview Image",
|
||||
"LoadImage": "Load Image",
|
||||
@ -2164,15 +2129,15 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"LoadImageOutput": "Load Image (from Outputs)",
|
||||
"ImageScale": "Upscale Image",
|
||||
"ImageScaleBy": "Upscale Image By",
|
||||
"ImageInvert": "Invert Image",
|
||||
"ImageInvert": "Invert Image Colors",
|
||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||
"ImageBatch": "Batch Images",
|
||||
"ImageCrop": "Image Crop",
|
||||
"ImageStitch": "Image Stitch",
|
||||
"ImageBlend": "Image Blend",
|
||||
"ImageBlur": "Image Blur",
|
||||
"ImageQuantize": "Image Quantize",
|
||||
"ImageSharpen": "Image Sharpen",
|
||||
"ImageBatch": "Batch Images (DEPRECATED)",
|
||||
"ImageCrop": "Crop Image",
|
||||
"ImageStitch": "Stitch Images",
|
||||
"ImageBlend": "Blend Images",
|
||||
"ImageBlur": "Blur Image",
|
||||
"ImageQuantize": "Quantize Image",
|
||||
"ImageSharpen": "Sharpen Image",
|
||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||
"GetImageSize": "Get Image Size",
|
||||
# _for_testing
|
||||
@ -2297,7 +2262,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
|
||||
return False
|
||||
else:
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).")
|
||||
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).")
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.warning(traceback.format_exc())
|
||||
@ -2447,6 +2412,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nop.py",
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
"nodes_ar_video.py",
|
||||
"nodes_image_compare.py",
|
||||
"nodes_zimage.py",
|
||||
"nodes_glsl.py",
|
||||
@ -2463,7 +2429,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_curve.py",
|
||||
"nodes_rtdetr.py",
|
||||
"nodes_frame_interpolation.py",
|
||||
"nodes_sam3.py"
|
||||
"nodes_sam3.py",
|
||||
"nodes_void.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
142
openapi.yaml
142
openapi.yaml
@ -631,7 +631,7 @@ paths:
|
||||
operationId: getFeatures
|
||||
tags: [system]
|
||||
summary: Get enabled feature flags
|
||||
description: Returns a dictionary of feature flag names to their enabled state.
|
||||
description: Returns a dictionary of feature flag names to their enabled state. Cloud deployments may include additional typed fields alongside the boolean flags.
|
||||
responses:
|
||||
"200":
|
||||
description: Feature flags
|
||||
@ -641,6 +641,43 @@ paths:
|
||||
type: object
|
||||
additionalProperties:
|
||||
type: boolean
|
||||
properties:
|
||||
max_upload_size:
|
||||
type: integer
|
||||
format: int64
|
||||
minimum: 0
|
||||
description: "Maximum file upload size in bytes."
|
||||
free_tier_credits:
|
||||
type: integer
|
||||
format: int32
|
||||
minimum: 0
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Credits available to free-tier users. Local ComfyUI returns null."
|
||||
posthog_api_host:
|
||||
type: string
|
||||
format: uri
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] PostHog analytics proxy URL for frontend telemetry. Local ComfyUI returns null."
|
||||
max_concurrent_jobs:
|
||||
type: integer
|
||||
format: int32
|
||||
minimum: 0
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Maximum concurrent jobs the authenticated user can run. Local ComfyUI returns null."
|
||||
workflow_templates_version:
|
||||
type: string
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Version identifier for the workflow templates bundle. Local ComfyUI returns null."
|
||||
workflow_templates_source:
|
||||
type: string
|
||||
nullable: true
|
||||
enum: [dynamic_config_override, workflow_templates_version_json]
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] How the templates version was resolved. Local ComfyUI returns null."
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Node / Object Info
|
||||
@ -1497,6 +1534,24 @@ paths:
|
||||
type: string
|
||||
enum: [asc, desc]
|
||||
description: Sort direction
|
||||
- name: job_ids
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job."
|
||||
- name: include_public
|
||||
in: query
|
||||
schema:
|
||||
type: boolean
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Include workspace-public assets in addition to the caller's own."
|
||||
- name: asset_hash
|
||||
in: query
|
||||
schema:
|
||||
type: string
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Filter by exact content hash."
|
||||
responses:
|
||||
"200":
|
||||
description: Asset list
|
||||
@ -1542,6 +1597,49 @@ paths:
|
||||
type: string
|
||||
format: uuid
|
||||
description: ID of an existing asset to use as the preview image
|
||||
id:
|
||||
type: string
|
||||
format: uuid
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Client-supplied asset ID for idempotent creation. If an asset with this ID already exists, the existing asset is returned."
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] URL-based asset upload. Caller supplies a URL instead of a file body; the server fetches the content."
|
||||
required:
|
||||
- url
|
||||
properties:
|
||||
url:
|
||||
type: string
|
||||
format: uri
|
||||
description: "[cloud-only] URL of the file to import as an asset"
|
||||
name:
|
||||
type: string
|
||||
description: Display name for the asset
|
||||
tags:
|
||||
type: string
|
||||
description: Comma-separated tags
|
||||
user_metadata:
|
||||
type: string
|
||||
description: JSON-encoded user metadata
|
||||
hash:
|
||||
type: string
|
||||
description: "Blake3 hash of the file content (e.g. blake3:abc123...)"
|
||||
mime_type:
|
||||
type: string
|
||||
description: MIME type of the file (overrides auto-detected type)
|
||||
preview_id:
|
||||
type: string
|
||||
format: uuid
|
||||
description: ID of an existing asset to use as the preview image
|
||||
id:
|
||||
type: string
|
||||
format: uuid
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Client-supplied asset ID for idempotent creation. If an asset with this ID already exists, the existing asset is returned."
|
||||
responses:
|
||||
"201":
|
||||
description: Asset created
|
||||
@ -1580,6 +1678,11 @@ paths:
|
||||
user_metadata:
|
||||
type: object
|
||||
additionalProperties: true
|
||||
mime_type:
|
||||
type: string
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] MIME type of the content, so the type is preserved without re-inspecting content. Ignored by local ComfyUI."
|
||||
responses:
|
||||
"201":
|
||||
description: Asset created from hash
|
||||
@ -1644,6 +1747,11 @@ paths:
|
||||
type: string
|
||||
format: uuid
|
||||
description: ID of the asset to use as the preview
|
||||
mime_type:
|
||||
type: string
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] MIME type override when auto-detection was wrong. Ignored by local ComfyUI."
|
||||
responses:
|
||||
"200":
|
||||
description: Asset updated
|
||||
@ -1999,6 +2107,18 @@ components:
|
||||
items:
|
||||
type: string
|
||||
description: List of node IDs to execute (partial graph execution)
|
||||
workflow_id:
|
||||
type: string
|
||||
format: uuid
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Cloud workflow entity ID for tracking and gallery association. Ignored by local ComfyUI."
|
||||
workflow_version_id:
|
||||
type: string
|
||||
format: uuid
|
||||
nullable: true
|
||||
x-runtime: [cloud]
|
||||
description: "[cloud-only] Cloud workflow version ID for pinning execution to a specific version. Ignored by local ComfyUI."
|
||||
|
||||
PromptResponse:
|
||||
type: object
|
||||
@ -2347,7 +2467,12 @@ components:
|
||||
description: Device type (cuda, mps, cpu, etc.)
|
||||
index:
|
||||
type: number
|
||||
description: Device index
|
||||
nullable: true
|
||||
description: |
|
||||
Device index within its type (e.g. CUDA ordinal for `cuda:0`,
|
||||
`cuda:1`). `null` for devices with no index, including the CPU
|
||||
device returned in `--cpu` mode (PyTorch's `torch.device('cpu').index`
|
||||
is `None`).
|
||||
vram_total:
|
||||
type: number
|
||||
description: Total VRAM in bytes
|
||||
@ -2503,7 +2628,18 @@ components:
|
||||
description: Alternative search terms for finding this node
|
||||
essentials_category:
|
||||
type: string
|
||||
description: Category override used by the essentials pack
|
||||
nullable: true
|
||||
description: |
|
||||
Category override used by the essentials pack. The
|
||||
`essentials_category` key may be present with a string value,
|
||||
present and `null`, or absent entirely:
|
||||
|
||||
- V1 nodes: `essentials_category` is **omitted** when the node
|
||||
class doesn't define an `ESSENTIALS_CATEGORY` attribute, and
|
||||
**`null`** if the attribute is explicitly set to `None`.
|
||||
- V3 nodes (`comfy_api.latest.io`): `essentials_category` is
|
||||
**always present**, and **`null`** for nodes whose `Schema`
|
||||
doesn't populate it.
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# Models
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.42.15
|
||||
comfyui-workflow-templates==0.9.63
|
||||
comfyui-frontend-package==1.43.17
|
||||
comfyui-workflow-templates==0.9.69
|
||||
comfyui-embedded-docs==0.4.4
|
||||
torch
|
||||
torchsde
|
||||
@ -19,7 +19,7 @@ scipy
|
||||
tqdm
|
||||
psutil
|
||||
alembic
|
||||
SQLAlchemy>=2.0
|
||||
SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
|
||||
17
server.py
17
server.py
@ -1,3 +1,4 @@
|
||||
import errno
|
||||
import os
|
||||
import sys
|
||||
import asyncio
|
||||
@ -559,7 +560,7 @@ class PromptServer():
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
|
||||
if 'channel' not in request.rel_url.query:
|
||||
channel = 'rgba'
|
||||
@ -579,7 +580,7 @@ class PromptServer():
|
||||
buffer.seek(0)
|
||||
|
||||
return web.Response(body=buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
|
||||
elif channel == 'a':
|
||||
with Image.open(file) as img:
|
||||
@ -596,7 +597,7 @@ class PromptServer():
|
||||
alpha_buffer.seek(0)
|
||||
|
||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
|
||||
else:
|
||||
# Use the content type from asset resolution if available,
|
||||
# otherwise guess from the filename.
|
||||
@ -613,7 +614,7 @@ class PromptServer():
|
||||
return web.FileResponse(
|
||||
file,
|
||||
headers={
|
||||
"Content-Disposition": f"filename=\"{filename}\"",
|
||||
"Content-Disposition": f"attachment; filename=\"{filename}\"",
|
||||
"Content-Type": content_type
|
||||
}
|
||||
)
|
||||
@ -1245,7 +1246,13 @@ class PromptServer():
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
|
||||
await site.start()
|
||||
try:
|
||||
await site.start()
|
||||
except OSError as e:
|
||||
if e.errno == errno.EADDRINUSE:
|
||||
logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.")
|
||||
raise SystemExit(1)
|
||||
raise
|
||||
|
||||
if not hasattr(self, 'address'):
|
||||
self.address = address #TODO: remove this
|
||||
|
||||
78
tests-unit/comfy_api_test/multicombo_serialization_test.py
Normal file
78
tests-unit/comfy_api_test/multicombo_serialization_test.py
Normal file
@ -0,0 +1,78 @@
|
||||
from comfy_api.latest._io import Combo, MultiCombo
|
||||
|
||||
|
||||
def test_multicombo_serializes_multi_select_as_object():
|
||||
multi_combo = MultiCombo.Input(
|
||||
id="providers",
|
||||
options=["a", "b", "c"],
|
||||
default=["a"],
|
||||
)
|
||||
|
||||
serialized = multi_combo.as_dict()
|
||||
|
||||
assert serialized["multiselect"] is True
|
||||
assert "multi_select" in serialized
|
||||
assert serialized["multi_select"] == {}
|
||||
|
||||
|
||||
def test_multicombo_serializes_multi_select_with_placeholder_and_chip():
|
||||
multi_combo = MultiCombo.Input(
|
||||
id="providers",
|
||||
options=["a", "b", "c"],
|
||||
default=["a"],
|
||||
placeholder="Select providers",
|
||||
chip=True,
|
||||
)
|
||||
|
||||
serialized = multi_combo.as_dict()
|
||||
|
||||
assert serialized["multiselect"] is True
|
||||
assert serialized["multi_select"] == {
|
||||
"placeholder": "Select providers",
|
||||
"chip": True,
|
||||
}
|
||||
|
||||
|
||||
def test_combo_does_not_serialize_multiselect():
|
||||
"""Regular Combo should not have multiselect in its serialized output."""
|
||||
combo = Combo.Input(
|
||||
id="choice",
|
||||
options=["a", "b", "c"],
|
||||
)
|
||||
|
||||
serialized = combo.as_dict()
|
||||
|
||||
# Combo sets multiselect=False, but prune_dict keeps False (not None),
|
||||
# so it should be present but False
|
||||
assert serialized.get("multiselect") is False
|
||||
assert "multi_select" not in serialized
|
||||
|
||||
|
||||
def _validate_combo_values(val, combo_options, is_multiselect):
|
||||
"""Reproduce the validation logic from execution.py for testing."""
|
||||
if is_multiselect and isinstance(val, list):
|
||||
return [v for v in val if v not in combo_options]
|
||||
else:
|
||||
return [val] if val not in combo_options else []
|
||||
|
||||
|
||||
def test_multicombo_validation_accepts_valid_list():
|
||||
options = ["a", "b", "c"]
|
||||
assert _validate_combo_values(["a", "b"], options, True) == []
|
||||
|
||||
|
||||
def test_multicombo_validation_rejects_invalid_values():
|
||||
options = ["a", "b", "c"]
|
||||
assert _validate_combo_values(["a", "x"], options, True) == ["x"]
|
||||
|
||||
|
||||
def test_multicombo_validation_accepts_empty_list():
|
||||
options = ["a", "b", "c"]
|
||||
assert _validate_combo_values([], options, True) == []
|
||||
|
||||
|
||||
def test_combo_validation_rejects_list_even_with_valid_items():
|
||||
"""A regular Combo should not accept a list value."""
|
||||
options = ["a", "b", "c"]
|
||||
invalid = _validate_combo_values(["a", "b"], options, False)
|
||||
assert len(invalid) > 0
|
||||
109
tests-unit/deploy_environment_test.py
Normal file
109
tests-unit/deploy_environment_test.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""Tests for comfy.deploy_environment."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy import deploy_environment
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_cache_and_install_dir(tmp_path, monkeypatch):
|
||||
"""Reset the functools cache and point the ComfyUI install dir at a tmp dir for each test."""
|
||||
get_deploy_environment.cache_clear()
|
||||
monkeypatch.setattr(deploy_environment, "_COMFY_INSTALL_DIR", str(tmp_path))
|
||||
yield
|
||||
get_deploy_environment.cache_clear()
|
||||
|
||||
|
||||
def _write_env_file(tmp_path, content: str) -> str:
|
||||
"""Write the env file with exact content (no newline translation).
|
||||
|
||||
`newline=""` disables Python's text-mode newline translation so the bytes
|
||||
on disk match the literal string passed in, regardless of host OS.
|
||||
Newline-style tests (CRLF, lone CR) rely on this.
|
||||
"""
|
||||
path = os.path.join(str(tmp_path), ".comfy_environment")
|
||||
with open(path, "w", encoding="utf-8", newline="") as f:
|
||||
f.write(content)
|
||||
return path
|
||||
|
||||
|
||||
class TestGetDeployEnvironment:
|
||||
def test_returns_local_git_when_file_missing(self):
|
||||
assert get_deploy_environment() == "local-git"
|
||||
|
||||
def test_reads_value_from_file(self, tmp_path):
|
||||
_write_env_file(tmp_path, "local-desktop2-standalone\n")
|
||||
assert get_deploy_environment() == "local-desktop2-standalone"
|
||||
|
||||
def test_strips_trailing_whitespace_and_newline(self, tmp_path):
|
||||
_write_env_file(tmp_path, " local-desktop2-standalone \n")
|
||||
assert get_deploy_environment() == "local-desktop2-standalone"
|
||||
|
||||
def test_only_first_line_is_used(self, tmp_path):
|
||||
_write_env_file(tmp_path, "first-line\nsecond-line\n")
|
||||
assert get_deploy_environment() == "first-line"
|
||||
|
||||
def test_crlf_line_ending(self, tmp_path):
|
||||
# Windows editors often save text files with CRLF line endings.
|
||||
# The CR must not end up in the returned value.
|
||||
_write_env_file(tmp_path, "local-desktop2-standalone\r\n")
|
||||
assert get_deploy_environment() == "local-desktop2-standalone"
|
||||
|
||||
def test_crlf_multiline_only_first_line_used(self, tmp_path):
|
||||
_write_env_file(tmp_path, "first-line\r\nsecond-line\r\n")
|
||||
assert get_deploy_environment() == "first-line"
|
||||
|
||||
def test_crlf_with_surrounding_whitespace(self, tmp_path):
|
||||
_write_env_file(tmp_path, " local-desktop2-standalone \r\n")
|
||||
assert get_deploy_environment() == "local-desktop2-standalone"
|
||||
|
||||
def test_lone_cr_line_ending(self, tmp_path):
|
||||
# Classic-Mac / some legacy editors use a bare CR.
|
||||
# Universal-newlines decoding treats it as a line terminator too.
|
||||
_write_env_file(tmp_path, "local-desktop2-standalone\r")
|
||||
assert get_deploy_environment() == "local-desktop2-standalone"
|
||||
|
||||
def test_empty_file_falls_back_to_default(self, tmp_path):
|
||||
_write_env_file(tmp_path, "")
|
||||
assert get_deploy_environment() == "local-git"
|
||||
|
||||
def test_empty_after_whitespace_strip_falls_back_to_default(self, tmp_path):
|
||||
_write_env_file(tmp_path, " \n")
|
||||
assert get_deploy_environment() == "local-git"
|
||||
|
||||
def test_strips_control_chars_within_first_line(self, tmp_path):
|
||||
# Embedded NUL/control chars in the value should be stripped
|
||||
# (header-injection / smuggling protection).
|
||||
_write_env_file(tmp_path, "abc\x00\x07xyz\n")
|
||||
assert get_deploy_environment() == "abcxyz"
|
||||
|
||||
def test_strips_non_ascii_characters(self, tmp_path):
|
||||
_write_env_file(tmp_path, "café-é\n")
|
||||
assert get_deploy_environment() == "caf-"
|
||||
|
||||
def test_caps_read_at_128_bytes(self, tmp_path):
|
||||
# A single huge line with no newline must not be fully read into memory.
|
||||
huge = "x" * 10_000
|
||||
_write_env_file(tmp_path, huge)
|
||||
result = get_deploy_environment()
|
||||
assert result == "x" * 128
|
||||
|
||||
def test_result_is_cached_across_calls(self, tmp_path):
|
||||
path = _write_env_file(tmp_path, "first_value\n")
|
||||
assert get_deploy_environment() == "first_value"
|
||||
# Overwrite the file — cached value should still be returned.
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
f.write("second_value\n")
|
||||
assert get_deploy_environment() == "first_value"
|
||||
|
||||
def test_unreadable_file_falls_back_to_default(self, tmp_path, monkeypatch):
|
||||
_write_env_file(tmp_path, "should_not_be_used\n")
|
||||
|
||||
def _boom(*args, **kwargs):
|
||||
raise OSError("simulated read failure")
|
||||
|
||||
monkeypatch.setattr("builtins.open", _boom)
|
||||
assert get_deploy_environment() == "local-git"
|
||||
@ -1,10 +1,15 @@
|
||||
"""Tests for feature flags functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy_api.feature_flags import (
|
||||
get_connection_feature,
|
||||
supports_feature,
|
||||
get_server_features,
|
||||
CLI_FEATURE_FLAG_REGISTRY,
|
||||
SERVER_FEATURE_FLAGS,
|
||||
_coerce_flag_value,
|
||||
_parse_cli_feature_flags,
|
||||
)
|
||||
|
||||
|
||||
@ -96,3 +101,83 @@ class TestFeatureFlags:
|
||||
result = get_connection_feature(sockets_metadata, "sid1", "any_feature")
|
||||
assert result is False
|
||||
assert supports_feature(sockets_metadata, "sid1", "any_feature") is False
|
||||
|
||||
|
||||
class TestCoerceFlagValue:
|
||||
"""Test suite for _coerce_flag_value."""
|
||||
|
||||
def test_registered_bool_true(self):
|
||||
assert _coerce_flag_value("show_signin_button", "true") is True
|
||||
assert _coerce_flag_value("show_signin_button", "True") is True
|
||||
|
||||
def test_registered_bool_false(self):
|
||||
assert _coerce_flag_value("show_signin_button", "false") is False
|
||||
assert _coerce_flag_value("show_signin_button", "FALSE") is False
|
||||
|
||||
def test_unregistered_key_stays_string(self):
|
||||
assert _coerce_flag_value("unknown_flag", "true") == "true"
|
||||
assert _coerce_flag_value("unknown_flag", "42") == "42"
|
||||
|
||||
def test_bool_typo_raises(self):
|
||||
"""Strict bool: typos like 'ture' or 'yes' must raise so the flag can be dropped."""
|
||||
with pytest.raises(ValueError):
|
||||
_coerce_flag_value("show_signin_button", "ture")
|
||||
with pytest.raises(ValueError):
|
||||
_coerce_flag_value("show_signin_button", "yes")
|
||||
with pytest.raises(ValueError):
|
||||
_coerce_flag_value("show_signin_button", "1")
|
||||
with pytest.raises(ValueError):
|
||||
_coerce_flag_value("show_signin_button", "")
|
||||
|
||||
def test_failed_int_coercion_raises(self, monkeypatch):
|
||||
"""Malformed values for typed flags must raise; caller decides what to do."""
|
||||
monkeypatch.setitem(
|
||||
CLI_FEATURE_FLAG_REGISTRY,
|
||||
"test_int_flag",
|
||||
{"type": "int", "default": 0, "description": "test"},
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
_coerce_flag_value("test_int_flag", "not_a_number")
|
||||
|
||||
|
||||
class TestParseCliFeatureFlags:
|
||||
"""Test suite for _parse_cli_feature_flags."""
|
||||
|
||||
def test_single_flag(self, monkeypatch):
|
||||
monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["show_signin_button=true"]})())
|
||||
result = _parse_cli_feature_flags()
|
||||
assert result == {"show_signin_button": True}
|
||||
|
||||
def test_missing_equals_defaults_to_true(self, monkeypatch):
|
||||
"""Bare flag without '=' is treated as the string 'true' (and coerced if registered)."""
|
||||
monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["show_signin_button", "valid=1"]})())
|
||||
result = _parse_cli_feature_flags()
|
||||
assert result == {"show_signin_button": True, "valid": "1"}
|
||||
|
||||
def test_empty_key_skipped(self, monkeypatch):
|
||||
monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["=value", "valid=1"]})())
|
||||
result = _parse_cli_feature_flags()
|
||||
assert result == {"valid": "1"}
|
||||
|
||||
def test_invalid_bool_value_dropped(self, monkeypatch, caplog):
|
||||
"""A typo'd bool value must be dropped entirely, not silently set to False
|
||||
and not stored as a raw string. A warning must be logged."""
|
||||
monkeypatch.setattr(
|
||||
"comfy_api.feature_flags.args",
|
||||
type("Args", (), {"feature_flag": ["show_signin_button=ture", "valid=1"]})(),
|
||||
)
|
||||
with caplog.at_level("WARNING"):
|
||||
result = _parse_cli_feature_flags()
|
||||
assert result == {"valid": "1"}
|
||||
assert "show_signin_button" not in result
|
||||
assert any("show_signin_button" in r.message and "drop" in r.message.lower() for r in caplog.records)
|
||||
|
||||
|
||||
class TestCliFeatureFlagRegistry:
|
||||
"""Test suite for the CLI feature flag registry."""
|
||||
|
||||
def test_registry_entries_have_required_fields(self):
|
||||
for key, info in CLI_FEATURE_FLAG_REGISTRY.items():
|
||||
assert "type" in info, f"{key} missing 'type'"
|
||||
assert "default" in info, f"{key} missing 'default'"
|
||||
assert "description" in info, f"{key} missing 'description'"
|
||||
|
||||
@ -69,7 +69,11 @@ async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
|
||||
assert len(result) == 1
|
||||
assert result[0]["path"] == "file1.txt"
|
||||
assert "size" in result[0]
|
||||
assert "modified" in result[0]
|
||||
assert isinstance(result[0]["modified"], int)
|
||||
assert isinstance(result[0]["created"], int)
|
||||
# Verify millisecond magnitude (timestamps after year 2000 in ms are > 946684800000)
|
||||
assert result[0]["modified"] > 946684800000
|
||||
assert result[0]["created"] > 946684800000
|
||||
|
||||
|
||||
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user