Compare commits

...

20 Commits

Author SHA1 Message Date
Talmaj
8affed1944
Merge b8688389e5 into f6d5068ac0 2026-05-03 02:44:13 -04:00
Alexis Rolland
f6d5068ac0
Update README (#13679)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Updated the README to include a new screenshot, improved description and add Ernie Image to supported models.
2026-05-03 12:20:17 +08:00
Jukka Seppänen
be95871adc
feat: Gemma4 text generation support (CORE-30) (#13376)
* initial gemma4 support

* parity with reference implementation

outputs can 100% match transformers with same sdpa flags, checkpoint this and then optimize

* Cleanup, video fixes

* cleanup, enable fused rms norm by default

* update comment

* Cleanup

* Update sd.py

* Various fixes

* Add fp8 scaled embedding support

* small fixes

* Translate think tokens

* Fix image encoder attention mask type

So it works with basic attention

* Handle thinking tokens different only for Gemma4

* Code cleanup

* Update nodes_textgen.py

* Use embed scale class instead of buffer

Slight difference to HF, but technically more accurate and simpler code

* Default to fused rms_norm

* Update gemma4.py
2026-05-02 22:46:15 -04:00
Alexander Piskun
f756d801a1
[Partner Nodes] Topaz Astra 2 model (#13672)
* feat(api-nodes): add Topaz Astra 2 model

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat(api-nodes): make Astra 2 the default Topaz upscaler model

Reorder UPSCALER_MODELS_MAP and the upscaler_model dynamic combo so
"Astra 2" appears first, surfacing it as the default selection.

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Marwan Mostafa <marawan206@gmail.com>
2026-05-02 19:29:00 -07:00
Daxiong (Lin)
1d23a875ed
chore: update workflow templates to v0.9.68 (#13678) 2026-05-03 10:06:55 +08:00
comfyanonymous
ef6722f6be
Some cleanups to the load image node. (#13677) 2026-05-02 20:34:27 -04:00
rattus
783782d5d7
Implement block prefetch + Lora Async load + and adopt in LTX (Speedup!) (CORE-111) (#13618)
* mm: Use Aimdo raw allocator for cast buffers

pytorch manages allocation of growing buffers on streams poorly. Pyt
has no windows support for the expandable segments allocator (which is
the right tool for this job), while also segmenting the memory by
stream such that it can be generally re-used. So kick the problem to
aimdo which can just grow a virtual region thats freed per stream.

* plan

* ops: move cpu handler up to the caller

* ops: split up prefetch from weight prep block prefetching API

Split up the casting and weight formating/lora stuff in prep for
arbitrary prefetch support.

* ops: implement block prefetching API

allow a model to construct a prefetch list and operate it for increased
async offload.

* ltxv2: Implement block prefetching

* Implement lora async offload

Implement async offload of loras.
2026-05-02 19:23:24 -04:00
Talmaj Marinc
b8688389e5 Create a dedicated node for ar_sampler. 2026-05-01 22:21:46 +02:00
Talmaj Marinc
4a63ef01ed Add better error handling for a custom ar_video sampler. 2026-05-01 11:37:22 +02:00
Talmaj Marinc
b5d1cdb2a8 Base frame_seq_len on the padded token grid. 2026-05-01 11:37:22 +02:00
Talmaj Marinc
df0a0dae36 Move KV cache end counter to Python int to avoid per-step host synchronization in AR sampling loops. 2026-05-01 11:37:22 +02:00
Talmaj Marinc
6f99c8086f Remove ar_convert, now present in hg repackaged model repo. 2026-05-01 11:37:22 +02:00
Talmaj Marinc
586647afdb Fix 'Process the tail block instead of truncating it', fix 'Don't mutate the patcher's shared transformer_options in place'. 2026-05-01 11:37:22 +02:00
Talmaj Marinc
6739b76a20 Remove dedicated ARLoader. 2026-05-01 11:37:19 +02:00
Talmaj Marinc
7caf1e801c Refactor CausalWanModel to inherit from WanModel. 2026-05-01 11:34:07 +02:00
Talmaj Marinc
3bd42bfaa4 Rewrite causual forcing using custom sampler with KSampler node. 2026-05-01 11:34:07 +02:00
Talmaj Marinc
09e46f8509 Apply ruff. 2026-05-01 11:34:07 +02:00
Talmaj Marinc
9d40dfbab8 Rename causual forcing to using more general auto regressive naming convention. 2026-05-01 11:34:07 +02:00
Talmaj Marinc
0edbe6d8dc Fix CausalForcingSampler. 2026-05-01 11:34:07 +02:00
Talmaj Marinc
2cdc4389d7 Initial commit causual_forcing. 2026-05-01 11:34:07 +02:00
27 changed files with 2577 additions and 127 deletions

View File

@ -1,7 +1,7 @@
<div align="center">
# ComfyUI
**The most powerful and modular visual AI engine and application.**
**The most powerful and modular AI engine for content creation.**
[![Website][website-shield]][website-url]
@ -31,10 +31,15 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/4aab0bef-b413-4595-9766-a2c134676d27" />
</div>
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
- ComfyUI natively supports the latest open-source state of the art models.
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
- It integrates seamlessly into production pipelines with our API endpoints.
## Get Started
@ -77,6 +82,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Ernie Image
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)

View File

@ -1810,3 +1810,102 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
"""Stochastic Adams Solver with PECE (PredictEvaluateCorrectEvaluate) mode (NeurIPS 2023)."""
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
@torch.no_grad()
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
num_frame_per_block=1):
"""
Autoregressive video sampler: block-by-block denoising with KV cache
and flow-match re-noising for Causal Forcing / Self-Forcing models.
Requires a Causal-WAN compatible model (diffusion_model must expose
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
All AR-loop parameters are passed via the SamplerARVideo node, not read
from the checkpoint or transformer_options.
"""
extra_args = {} if extra_args is None else extra_args
model_options = extra_args.get("model_options", {})
transformer_options = model_options.get("transformer_options", {})
if x.ndim != 5:
raise ValueError(
f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
"This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
)
inner_model = model.inner_model.inner_model
causal_model = inner_model.diffusion_model
if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
raise TypeError(
"ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
"exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
"does not support this interface — choose a different sampler."
)
seed = extra_args.get("seed", 0)
bs, c, lat_t, lat_h, lat_w = x.shape
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
device = x.device
model_dtype = inner_model.get_dtype()
kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
output = torch.zeros_like(x)
s_in = x.new_ones([x.shape[0]])
current_start_frame = 0
num_sigma_steps = len(sigmas) - 1
total_real_steps = num_blocks * num_sigma_steps
step_count = 0
try:
for block_idx in trange(num_blocks, disable=disable):
bf = min(num_frame_per_block, lat_t - current_start_frame)
fs, fe = current_start_frame, current_start_frame + bf
noisy_input = x[:, :, fs:fe]
ar_state = {
"start_frame": current_start_frame,
"kv_caches": kv_caches,
"crossattn_caches": crossattn_caches,
}
transformer_options["ar_state"] = ar_state
for i in range(num_sigma_steps):
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
if callback is not None:
scaled_i = step_count * num_sigma_steps // total_real_steps
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
"sigma_hat": sigmas[i], "denoised": denoised})
if sigmas[i + 1] == 0:
noisy_input = denoised
else:
sigma_next = sigmas[i + 1]
torch.manual_seed(seed + block_idx * 1000 + i)
fresh_noise = torch.randn_like(denoised)
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
for cache in kv_caches:
cache["end"] -= bf * frame_seq_len
step_count += 1
output[:, :, fs:fe] = noisy_input
for cache in kv_caches:
cache["end"] -= bf * frame_seq_len
zero_sigma = sigmas.new_zeros([1])
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
current_start_frame += bf
finally:
transformer_options.pop("ar_state", None)
return output

View File

@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
import comfy.model_prefetch
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
a_prompt_timestep=a_prompt_timestep,
)
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):

View File

@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
if model_management.xformers_enabled():
import xformers
import xformers.ops
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
n_rep = q.shape[-3] // k.shape[-3]
k = k.repeat_interleave(n_rep, dim=-3)
v = v.repeat_interleave(n_rep, dim=-3)
scale = kwargs.get("scale", dim_head ** -0.5)
h = heads
if skip_reshape:
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape
dim_head //= heads
if "scale" in kwargs:
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
query = query * (kwargs["scale"] * dim_head ** 0.5)
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape:
q, k, v = map(
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
if SDP_BATCH_LIMIT >= b:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
dropout_p=0.0, is_causal=False
dropout_p=0.0, is_causal=False, **sdpa_extra
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out

276
comfy/ldm/wan/ar_model.py Normal file
View File

@ -0,0 +1,276 @@
"""
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
autoregressive (frame-by-frame) video generation via Causal Forcing.
Weight-compatible with the standard WanModel -- same layer names, same shapes.
The difference is purely in the forward pass: this model processes one temporal
block at a time and maintains a KV cache across blocks.
Reference: https://github.com/thu-ml/Causal-Forcing
"""
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.wan.model import (
sinusoidal_embedding_1d,
repeat_e,
WanModel,
WanAttentionBlock,
)
import comfy.ldm.common_dit
import comfy.model_management
class CausalWanSelfAttention(nn.Module):
"""Self-attention with KV cache support for autoregressive inference."""
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
eps=1e-6, operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qk_norm = qk_norm
self.eps = eps
ops = operation_settings.get("operations")
device = operation_settings.get("device")
dtype = operation_settings.get("dtype")
self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
def forward(self, x, freqs, kv_cache=None, transformer_options={}):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
v = self.v(x).view(b, s, n, d)
if kv_cache is None:
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v.view(b, s, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
else:
end = kv_cache["end"]
new_end = end + s
# Roped K and plain V go into cache
kv_cache["k"][:, end:new_end] = k
kv_cache["v"][:, end:new_end] = v
kv_cache["end"] = new_end
x = optimized_attention(
q.view(b, s, n * d),
kv_cache["k"][:, :new_end].view(b, new_end, n * d),
kv_cache["v"][:, :new_end].view(b, new_end, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
x = self.o(x)
return x
class CausalWanAttentionBlock(WanAttentionBlock):
"""Transformer block with KV-cached self-attention and cross-attention caching."""
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
eps=1e-6, operation_settings={}):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps,
operation_settings=operation_settings)
self.self_attn = CausalWanSelfAttention(
dim, num_heads, window_size, qk_norm, eps,
operation_settings=operation_settings)
def forward(self, x, e, freqs, context, context_img_len=257,
kv_cache=None, crossattn_cache=None, transformer_options={}):
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# Self-attention with optional KV cache
x = x.contiguous()
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, kv_cache=kv_cache, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# Cross-attention with optional caching
if crossattn_cache is not None and crossattn_cache.get("is_init"):
q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
x_ca = optimized_attention(
q, crossattn_cache["k"], crossattn_cache["v"],
heads=self.num_heads, transformer_options=transformer_options)
x = x + self.cross_attn.o(x_ca)
else:
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if crossattn_cache is not None:
crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
crossattn_cache["v"] = self.cross_attn.v(context)
crossattn_cache["is_init"] = True
# FFN
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
class CausalWanModel(WanModel):
"""
Wan 2.1 diffusion backbone with causal KV-cache support.
Same weight structure as WanModel -- loads identical state dicts.
Adds forward_block() for frame-by-frame autoregressive inference.
"""
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
image_model=None,
device=None,
dtype=None,
operations=None):
super().__init__(
model_type=model_type, patch_size=patch_size, text_len=text_len,
in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
wan_attn_block_class=CausalWanAttentionBlock,
device=device, dtype=dtype, operations=operations)
def forward_block(self, x, timestep, context, start_frame,
kv_caches, crossattn_caches, clip_fea=None):
"""
Forward one temporal block for autoregressive inference.
Args:
x: [B, C, block_frames, H, W] input latent for the current block
timestep: [B, block_frames] per-frame timesteps
context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
start_frame: temporal frame index for RoPE offset
kv_caches: list of per-layer KV cache dicts
crossattn_caches: list of per-layer cross-attention cache dicts
clip_fea: optional CLIP features for I2V
Returns:
flow_pred: [B, C_out, block_frames, H, W] flow prediction
"""
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
bs, c, t, h, w = x.shape
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# Per-frame time embedding
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# Text embedding (reuses crossattn_cache after first block)
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea)
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
# RoPE for current block's temporal position
freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
# Transformer blocks
for i, block in enumerate(self.blocks):
x = block(x, e=e0, freqs=freqs, context=context,
context_img_len=context_img_len,
kv_cache=kv_caches[i],
crossattn_cache=crossattn_caches[i])
# Head
x = self.head(x, e)
# Unpatchify
x = self.unpatchify(x, grid_sizes)
return x[:, :, :t, :h, :w]
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
"""Create fresh KV caches for all layers."""
caches = []
for _ in range(self.num_layers):
caches.append({
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
"end": 0,
})
return caches
def init_crossattn_caches(self, batch_size, device, dtype):
"""Create fresh cross-attention caches for all layers."""
caches = []
for _ in range(self.num_layers):
caches.append({"is_init": False})
return caches
def reset_kv_caches(self, kv_caches):
"""Reset KV caches to empty (reuse allocated memory)."""
for cache in kv_caches:
cache["end"] = 0
def reset_crossattn_caches(self, crossattn_caches):
"""Reset cross-attention caches."""
for cache in crossattn_caches:
cache["is_init"] = False
@property
def head_dim(self):
return self.dim // self.num_heads
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
ar_state = transformer_options.get("ar_state")
if ar_state is not None:
bs = x.shape[0]
block_frames = x.shape[2]
t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
return self.forward_block(
x=x, timestep=t_per_frame, context=context,
start_frame=ar_state["start_frame"],
kv_caches=ar_state["kv_caches"],
crossattn_caches=ar_state["crossattn_caches"],
clip_fea=clip_fea,
)
return super().forward(x, timestep, context, clip_fea=clip_fea,
time_dim_concat=time_dim_concat,
transformer_options=transformer_options, **kwargs)

View File

@ -17,6 +17,7 @@
"""
from __future__ import annotations
import comfy.memory_management
import comfy.utils
import comfy.model_management
import comfy.model_base
@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
weight = old_weight
return weight
def prefetch_prepared_value(value, allocate_buffer, stream):
if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
return value

View File

@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
@ -214,6 +215,11 @@ class BaseModel(torch.nn.Module):
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
transformer_options = transformer_options.copy()
transformer_options["prefetch_dynamic_vbars"] = (
self.current_patcher is not None and self.current_patcher.is_dynamic()
)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)
@ -1360,6 +1366,13 @@ class WAN21(BaseModel):
return out
class WAN21_CausalAR(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
self.image_to_video = False
class WAN21_Vace(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.vram_buffer
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -1175,6 +1176,10 @@ stream_counters = {}
STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT
@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
return cast_buffer
def get_aimdo_cast_buffer(offload_stream, device):
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):

View File

@ -121,9 +121,20 @@ class LowVramPatch:
self.patches = patches
self.convert_func = convert_func # TODO: remove
self.set_func = set_func
self.prepared_patches = None
def prepare(self, allocate_buffer, stream):
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]
def clear_prepared(self):
self.prepared_patches = None
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

65
comfy/model_prefetch.py Normal file
View File

@ -0,0 +1,65 @@
import comfy_aimdo.model_vbar
import comfy.model_management
import comfy.ops
PREFETCH_QUEUES = []
def cleanup_prefetched_modules(comfy_modules):
for s in comfy_modules:
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
continue
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
if prefetch["signature"] is not None:
comfy_aimdo.model_vbar.vbar_unpin(s._v)
delattr(s, "_prefetch")
def cleanup_prefetch_queues():
global PREFETCH_QUEUES
for queue in PREFETCH_QUEUES:
for entry in queue:
if entry is None or not isinstance(entry, tuple):
continue
_, prefetch_state = entry
comfy_modules = prefetch_state[1]
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
PREFETCH_QUEUES = []
def prefetch_queue_pop(queue, device, module):
if queue is None:
return
consumed = queue.pop(0)
if consumed is not None:
offload_stream, prefetch_state = consumed
offload_stream.wait_stream(comfy.model_management.current_stream(device))
_, comfy_modules = prefetch_state
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
prefetch = queue[0]
if prefetch is not None:
comfy_modules = []
for s in prefetch.modules():
if hasattr(s, "_v"):
comfy_modules.append(s)
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))
def make_prefetch_queue(queue, device, transformer_options):
if (not transformer_options.get("prefetch_dynamic_vbars", False)
or comfy.model_management.NUM_STREAMS == 0
or comfy.model_management.is_device_cpu(device)
or not comfy.model_management.device_supports_non_blocking(device)):
return None
queue = [None] + queue + [None]
PREFETCH_QUEUES.append(queue)
return queue

View File

@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
# FIXME: add n=1 cache hit fast path
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
offload_stream = None
xfer_dest = None
cast_buffer = None
cast_buffer_offset = 0
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
nonlocal cast_buffer
if offload_stream is None:
offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
return
current_size = 0 if cast_buffer is None else cast_buffer.size()
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = None
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
def get_cast_buffer(buffer_size):
nonlocal offload_stream
nonlocal cast_buffer
nonlocal cast_buffer_offset
if buffer_size == 0:
return None
if offload_stream is None:
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
cast_buffer_offset += buffer_size
return buffer
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
prefetch = {
"signature": signature,
"resident": resident,
}
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
s._prefetch = prefetch
continue
if not resident:
materialize_meta_param(s, ["weight", "bias"])
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
needs_cast = False
xfer_source = [ s.weight, s.bias ]
@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if data is None:
continue
if data.dtype != geometry.dtype:
needs_cast = True
cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
if cast_dest is not None:
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
prefetch["cast_geometry"] = cast_geometry
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
return offload_stream
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
prefetch = getattr(s, "_prefetch", None)
if prefetch["resident"]:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = prefetch["xfer_dest"]
if prefetch["needs_cast"]:
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
weight = params[0]
bias = params[1]
if signature is not None:
if prefetch["signature"] is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
s._v_signature = prefetch["signature"]
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
if x is None:
return None
orig = x
def to_dequant(tensor, dtype):
@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = f(x)
return x
update_weight = signature is not None
update_weight = prefetch["signature"] is not None
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
if bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
return weight, bias
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
def format_return(result, offloadable):
weight, bias, offload_stream = result
return (weight, bias, offload_stream) if offloadable else (weight, bias)
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return format_return((weight, bias, (None, None, None)), offloadable)
prefetched = hasattr(s, "_prefetch")
offload_stream = None
offload_device = None
if not prefetched:
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
comfy.model_management.sync_stream(device, offload_stream)
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None:
offload_device = device
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
delattr(s, "_prefetch")
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.weight_function:
weight = f(weight)
if offloadable:
return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
def uncast_bias_weight(s, weight, bias, offload_stream):
@ -1173,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf)
return self
class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
if scale is not None:
scale = scale.float()
manually_loaded_keys.append(scale_key)
params = layout_cls.Params(
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.num_embeddings, self.embedding_dim),
)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight
# Optimized path: lookup in fp8, dequantize only the selected rows.
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
if isinstance(qdata, QuantizedTensor):
scale = qdata._params.scale
qdata = qdata._qdata
else:
scale = None
x = torch.nn.functional.embedding(
input, qdata, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
uncast_bias_weight(self, qdata, None, offload_stream)
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
x = x.to(dtype=target_dtype)
if scale is not None and scale != 1.0:
x = x * scale.to(dtype=target_dtype)
return x
# Fallback for non-quantized or weight_function (LoRA) case
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):

View File

@ -3,6 +3,7 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
def rms_norm(x, weight=None, eps=1e-6):
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)

View File

@ -65,6 +65,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.model_patcher
import comfy.lora
@ -1271,6 +1272,9 @@ class TEModel(Enum):
QWEN35_9B = 26
QWEN35_27B = 27
MINISTRAL_3_3B = 28
GEMMA_4_E4B = 29
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
def detect_te_model(sd):
@ -1296,6 +1300,12 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E4B
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E2B
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

View File

@ -1167,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
class WAN21_CausalAR_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
"causal_ar": True,
}
sampling_settings = {
"shift": 5.0,
}
def __init__(self, unet_config):
super().__init__(unet_config)
self.unet_config.pop("causal_ar", None)
def get_model(self, state_dict, prefix="", device=None):
return model_base.WAN21_CausalAR(self, device=device)
class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -1929,6 +1948,7 @@ models = [
ZImage,
Lumina2,
WAN22_T2V,
WAN21_CausalAR_T2V,
WAN21_T2V,
WAN21_I2V,
WAN21_FunControl2V,

File diff suppressed because it is too large Load Diff

View File

@ -521,7 +521,7 @@ class Attention(nn.Module):
else:
present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window:
if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
@ -533,12 +533,12 @@ class Attention(nn.Module):
return self.o_proj(output), present_key_value
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
super().__init__()
ops = ops or nn
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.mlp_activation == "silu":
self.activation = torch.nn.functional.silu
elif config.mlp_activation == "gelu_pytorch_tanh":
@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
return x, present_key_value
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
class ScaledEmbedding(ops.Embedding):
def forward(self, input_ids, out_dtype=None):
return super().forward(input_ids, out_dtype=out_dtype) * scale
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.normalize_in = True
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
else:
transformer = TransformerBlock
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
@ -690,15 +691,12 @@ class Llama2_(nn.Module):
self.config.rope_dims,
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
if embeds is not None:
x = embeds
else:
x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
@ -850,7 +848,7 @@ class BaseGenerate:
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
device = embeds.device
if stop_tokens is None:
@ -875,14 +873,16 @@ class BaseGenerate:
pbar = comfy.utils.ProgressBar(max_length)
# Generation loop
current_input_ids = initial_input_ids
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
current_input_ids = next_token if initial_input_ids is not None else None
pbar.update(1)
if token_id in stop_tokens:

View File

@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):

View File

@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
embeds, _, _, _ = super().process_tokens(tokens, device)
return embeds
class LuminaModel(sd1_clip.SD1ClipModel):

View File

@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
nn.Module.__init__(self)
self.config = config
self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)

View File

@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional
from pydantic import BaseModel, Field
@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
grain: Optional[float] = Field(None, description="Grain after AI model processing")
grainSize: Optional[float] = Field(None, description="Size of generated grain")
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
realism: float | None = Field(None, description="ast-2 realism control")
class OutputInformationVideo(BaseModel):
@ -90,7 +93,7 @@ class Overrides(BaseModel):
class CreateVideoRequest(BaseModel):
source: CreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))

View File

@ -36,11 +36,15 @@ from comfy_api_nodes.util import (
)
UPSCALER_MODELS_MAP = {
"Astra 2": "ast-2",
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
"Starlight Precise 2.5": "slp-2.5",
}
AST2_MAX_FRAMES = 9000
AST2_MAX_FRAMES_WITH_PROMPT = 450
class TopazImageEnhance(IO.ComfyNode):
@classmethod
@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance",
display_name="Topaz Video Enhance (Legacy)",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
IO.Combo.Input(
"upscaler_model",
options=[
"Starlight (Astra) Fast",
"Starlight (Astra) Creative",
"Starlight Precise 2.5",
],
),
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"upscaler_creativity",
@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazVideoEnhanceV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhanceV2",
display_name="Topaz Video Enhance",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.DynamicCombo.Input(
"upscaler_model",
options=[
IO.DynamicCombo.Option(
"Astra 2",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Float.Input(
"creativity",
default=0.5,
min=0.0,
max=1.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Creative strength of the upscale.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional descriptive (not instructive) scene prompt."
f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.",
),
IO.Float.Input(
"sharp",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pre-enhance sharpness: "
"0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.",
advanced=True,
),
IO.Float.Input(
"realism",
default=0.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pulls output toward photographic realism."
"Leave at 0 for the model default.",
advanced=True,
),
],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Fast",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Creative",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"creativity",
options=["low", "middle", "high"],
default="low",
tooltip="Creative strength of the upscale.",
),
],
),
IO.DynamicCombo.Option(
"Starlight Precise 2.5",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
),
IO.DynamicCombo.Option("Disabled", []),
],
),
IO.DynamicCombo.Input(
"interpolation_model",
options=[
IO.DynamicCombo.Option("Disabled", []),
IO.DynamicCombo.Option(
"apo-8",
[
IO.Int.Input(
"interpolation_frame_rate",
default=60,
min=15,
max=240,
display_mode=IO.NumberDisplay.number,
tooltip="Output frame rate.",
),
IO.Int.Input(
"interpolation_slowmo",
default=1,
min=1,
max=16,
display_mode=IO.NumberDisplay.number,
tooltip="Slow-motion factor applied to the input video. "
"For example, 2 makes the output twice as slow and doubles the duration.",
advanced=True,
),
IO.Boolean.Input(
"interpolation_duplicate",
default=False,
tooltip="Analyze the input for duplicate frames and remove them.",
advanced=True,
),
IO.Float.Input(
"interpolation_duplicate_threshold",
default=0.01,
min=0.001,
max=0.1,
step=0.001,
display_mode=IO.NumberDisplay.number,
tooltip="Detection sensitivity for duplicate frames.",
advanced=True,
),
],
),
],
),
IO.Combo.Input(
"dynamic_compression_level",
options=["Low", "Mid", "High"],
default="Low",
tooltip="CQP level.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=[
"upscaler_model",
"upscaler_model.upscaler_resolution",
"interpolation_model",
]),
expr="""
(
$model := $lookup(widgets, "upscaler_model");
$res := $lookup(widgets, "upscaler_model.upscaler_resolution");
$interp := $lookup(widgets, "interpolation_model");
$is4k := $contains($res, "4k");
$hasInterp := $interp != "disabled";
$rates := {
"starlight (astra) fast": {"hd": 0.43, "uhd": 0.85},
"starlight precise 2.5": {"hd": 0.70, "uhd": 1.54},
"astra 2": {"hd": 1.72, "uhd": 2.85},
"starlight (astra) creative": {"hd": 2.25, "uhd": 3.99}
};
$surcharge := $is4k ? 0.28 : 0.14;
$entry := $lookup($rates, $model);
$base := $is4k ? $entry.uhd : $entry.hd;
$hi := $base + ($hasInterp ? $surcharge : 0);
$model = "disabled"
? {"type":"text","text":"Interpolation only"}
: ($hasInterp
? {"type":"text","text":"~" & $string($base) & "" & $string($hi) & " credits/src frame"}
: {"type":"text","text":"~" & $string($base) & " credits/src frame"})
)
""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
upscaler_model: dict,
interpolation_model: dict,
dynamic_compression_level: str = "Low",
) -> IO.NodeOutput:
upscaler_choice = upscaler_model["upscaler_model"]
interpolation_choice = interpolation_model["interpolation_model"]
if upscaler_choice == "Disabled" and interpolation_choice == "Disabled":
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
validate_container_format_is_mp4(video)
src_width, src_height = video.get_dimensions()
src_frame_rate = int(video.get_frame_rate())
duration_sec = video.get_duration()
src_video_stream = video.get_stream_source()
target_width = src_width
target_height = src_height
target_frame_rate = src_frame_rate
filters = []
if upscaler_choice != "Disabled":
if "1080p" in upscaler_model["upscaler_resolution"]:
target_pixel_p = 1080
max_long_side = 1920
else:
target_pixel_p = 2160
max_long_side = 3840
ar = src_width / src_height
if src_width >= src_height:
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
target_height = target_pixel_p
target_width = int(target_height * ar)
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
if target_width > max_long_side:
target_width = max_long_side
target_height = int(target_width / ar)
else:
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
target_width = target_pixel_p
target_height = int(target_width / ar)
# Check if height exceeds standard bounds
if target_height > max_long_side:
target_height = max_long_side
target_width = int(target_height * ar)
if target_width % 2 != 0:
target_width += 1
if target_height % 2 != 0:
target_height += 1
model_id = UPSCALER_MODELS_MAP[upscaler_choice]
if model_id == "slc-1":
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
isOptimizedMode=True,
)
)
elif model_id == "ast-2":
n_frames = video.get_frame_count()
ast2_prompt = (upscaler_model["prompt"] or "").strip()
if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT:
raise ValueError(
f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames "
f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip."
)
if n_frames > AST2_MAX_FRAMES:
raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.")
realism = upscaler_model["realism"]
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
prompt=(ast2_prompt or None),
sharp=upscaler_model["sharp"],
realism=(realism if realism > 0 else None),
)
)
else:
filters.append(VideoEnhancementFilter(model=model_id))
if interpolation_choice != "Disabled":
target_frame_rate = interpolation_model["interpolation_frame_rate"]
filters.append(
VideoFrameInterpolationFilter(
model=interpolation_choice,
slowmo=interpolation_model["interpolation_slowmo"],
fps=interpolation_model["interpolation_frame_rate"],
duplicate=interpolation_model["interpolation_duplicate"],
duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"],
),
)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
response_model=CreateVideoResponse,
data=CreateVideoRequest(
source=CreateVideoRequestSource(
container="mp4",
size=get_fs_object_size(src_video_stream),
duration=int(duration_sec),
frameCount=video.get_frame_count(),
frameRate=src_frame_rate,
resolution=Resolution(width=src_width, height=src_height),
),
filters=filters,
output=OutputInformationVideo(
resolution=Resolution(width=target_width, height=target_height),
frameRate=target_frame_rate,
audioCodec="AAC",
audioTransfer="Copy",
dynamicCompressionLevel=dynamic_compression_level,
),
),
wait_label="Creating task",
final_label_on_success="Task created",
)
upload_res = await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
method="PATCH",
),
response_model=VideoAcceptResponse,
wait_label="Preparing upload",
final_label_on_success="Upload started",
)
if len(upload_res.urls) > 1:
raise NotImplementedError(
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
)
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
if isinstance(src_video_stream, BytesIO):
src_video_stream.seek(0)
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
else:
with builtins.open(src_video_stream, "rb") as video_file:
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
method="PATCH",
),
response_model=VideoCompleteUploadResponse,
data=VideoCompleteUploadRequest(
uploadResults=[
VideoCompleteUploadRequestPart(
partNum=1,
eTag=upload_etag,
),
],
),
wait_label="Finalizing upload",
final_label_on_success="Upload completed",
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
response_model=VideoStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
poll_interval=10.0,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TopazImageEnhance,
TopazVideoEnhance,
TopazVideoEnhanceV2,
]

View File

@ -0,0 +1,84 @@
"""
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
"""
import torch
from typing_extensions import override
import comfy.model_management
import comfy.samplers
from comfy_api.latest import ComfyExtension, io
class EmptyARVideoLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyARVideoLatent",
category="latent/video",
inputs=[
io.Int.Input("width", default=832, min=16, max=8192, step=16),
io.Int.Input("height", default=480, min=16, max=8192, step=16),
io.Int.Input("length", default=81, min=1, max=1024, step=4),
io.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[
io.Latent.Output(display_name="LATENT"),
],
)
@classmethod
def execute(cls, width, height, length, batch_size) -> io.NodeOutput:
lat_t = ((length - 1) // 4) + 1
latent = torch.zeros(
[batch_size, 16, lat_t, height // 8, width // 8],
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput({"samples": latent})
class SamplerARVideo(io.ComfyNode):
"""Sampler for autoregressive video models (Causal Forcing, Self-Forcing).
All AR-loop parameters are owned by this node so they live in the workflow.
Add new widgets here as the AR sampler grows new options.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SamplerARVideo",
display_name="Sampler AR Video",
category="sampling/custom_sampling/samplers",
inputs=[
io.Int.Input(
"num_frame_per_block",
default=1, min=1, max=64,
tooltip="Frames per autoregressive block. 1 = framewise, "
"3 = chunkwise. Must match the checkpoint's training mode.",
),
],
outputs=[io.Sampler.Output()],
)
@classmethod
def execute(cls, num_frame_per_block) -> io.NodeOutput:
extra_options = {
"num_frame_per_block": num_frame_per_block,
}
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
class ARVideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyARVideoLatent,
SamplerARVideo,
]
async def comfy_entrypoint() -> ARVideoExtension:
return ARVideoExtension()

View File

@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("image", optional=True),
io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
io.Audio.Input("audio", optional=True),
io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
# Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on"
@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode):
seed=seed
)
generated_text = clip.decode(generated_ids, skip_special_tokens=True)
generated_text = clip.decode(generated_ids)
return io.NodeOutput(generated_text)
@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
class TextgenExtension(ComfyExtension):

View File

@ -15,6 +15,7 @@ import torch
from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
import comfy.model_prefetch
import comfy_aimdo.model_vbar
from latent_preview import set_preview_method
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if args.verbose == "DEBUG":
comfy_aimdo.control.analyze()
comfy.model_management.reset_cast_buffers()
comfy.model_prefetch.cleanup_prefetch_queues()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks:

View File

@ -1694,26 +1694,27 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
dtype = comfy.model_management.intermediate_dtype()
device = comfy.model_management.intermediate_device()
components = InputImpl.VideoFromFile(image_path).get_components()
if components.images.shape[0] > 0:
return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu"))
return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device))
# This code is left here to handle animated webp which pyav does not support loading
img = node_helpers.pillow(Image.open, image_path)
output_images = []
output_masks = []
w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
@ -1728,25 +1729,15 @@ class LoadImage:
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO":
break # ignore all frames except the first one for MPO format
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype))
@classmethod
def IS_CHANGED(s, image):
@ -2428,6 +2419,7 @@ async def init_builtin_extra_nodes():
"nodes_nop.py",
"nodes_kandinsky5.py",
"nodes_wanmove.py",
"nodes_ar_video.py",
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_glsl.py",

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.15
comfyui-workflow-templates==0.9.66
comfyui-workflow-templates==0.9.68
comfyui-embedded-docs==0.4.4
torch
torchsde