Compare commits

...

56 Commits

Author SHA1 Message Date
Yousef R. Gamaleldin
2bc68dbd8a
Merge dbf8f9dcf9 into 2e9d51680a 2026-01-08 12:50:22 +08:00
comfyanonymous
2e9d51680a ComfyUI version v0.8.2
2026-01-07 23:50:02 -05:00
comfyanonymous
50d6e1caf4
Tweak ltxv vae mem estimation. (#11722) 2026-01-07 23:07:05 -05:00
comfyanonymous
ac12f77bed ComfyUI version v0.8.1 2026-01-07 22:10:08 -05:00
ComfyUI Wiki
fcd9a236b0
Update template to 0.7.69 (#11719) 2026-01-07 18:22:23 -08:00
comfyanonymous
21e8425087
Add warning for old pytorch. (#11718) 2026-01-07 21:07:26 -05:00
rattus
b6c79a648a
ops: Fix offloading with FP8MM performance (#11697)
This logic was checking comfy_cast_weights and going straight to
the forward_comfy_cast_weights implementation without attempting
to downscale the input to fp8 when comfy_cast_weights is set.

The main reason comfy_cast_weights would be set is async offload,
which is not a good reason to nix FP8MM.

So instead, AND together the underlying exclusions for FP8MM, which
are:

* having a weight_function (usually LowVramPatch)
* force_cast_weights (compute dtype override)
* the weight is not quantized
* the input is already quantized
* the model or layer has MM explicitly disabled

If you get past all of those exclusions, quantize the input tensor,
then hand the new input, quantized or not, off to
forward_comfy_cast_weights to handle it. If the weight is offloaded
but the input is quantized you get an offloaded MM8. (See the sketch
after this entry.)
2026-01-07 21:01:16 -05:00
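A minimal sketch of the combined exclusion check described above, with illustrative flag names standing in for the attributes touched in the ops diff further down (not the exact implementation):

def can_quantize_input(has_weight_patch, force_cast_weights,
                       weight_is_quantized, input_is_quantized, mm_disabled):
    # AND the exclusions together: if any one holds, skip input quantization
    # and hand the original input to forward_comfy_cast_weights.
    excluded = (has_weight_patch or force_cast_weights
                or not weight_is_quantized or input_is_quantized or mm_disabled)
    return not excluded

# Async offload alone no longer disables FP8MM: with a quantized weight and a
# plain input, the input is quantized and an offloaded MM8 remains possible.
print(can_quantize_input(False, False, True, False, False))  # True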
comfyanonymous
25bc1b5b57
Add memory estimation function to ltxav text encoder. (#11716) 2026-01-07 20:11:22 -05:00
comfyanonymous
3cd19e99c1
Increase ltxav mem estimation by a bit. (#11715) 2026-01-07 20:04:56 -05:00
comfyanonymous
007b87e7ac
Bump required comfy-kitchen version. (#11714) 2026-01-07 19:48:47 -05:00
comfyanonymous
34751fe9f9
Lower ltxv text encoder vram use. (#11713) 2026-01-07 19:12:15 -05:00
Jukka Seppänen
1c705f7bfb
Add device selection for LTXAVTextEncoderLoader (#11700) 2026-01-07 18:39:59 -05:00
rattus
48e5ea1dfd
model_patcher: Remove confusing load stat (#11710)
If the loader passes 1e32 as the usable memory size, it means force
the full load. This happens with CPU loads and a few other misc cases.
Remove the confusing number and just leave the other details. (See the
note after this entry.)
2026-01-07 18:39:20 -05:00
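For illustration only (not part of the patch): the 1e32 sentinel renders as an absurd figure when formatted the old way, which is why the number was dropped from the log line.

# 1e32 bytes is a "force the full load" sentinel, not a real memory budget;
# dividing by 1024*1024 and printing it yields roughly 9.5e25 MB in the log.
sentinel_bytes = 1e32
print("{:.2f} MB usable,".format(sentinel_bytes / (1024 * 1024)))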
Yousef R. Gamaleldin
dbf8f9dcf9
Merge branch 'master' into seedvr2 2026-01-07 11:57:40 +02:00
Yousef Rafat
72ca18acc2 . 2026-01-04 20:32:38 +02:00
Yousef Rafat
f588e6c821 ruff 2026-01-04 20:30:24 +02:00
Yousef Rafat
0da072e098 Merge branch 'seedvr2' of https://github.com/yousef-rafat/ComfyUI into seedvr2 2026-01-04 19:17:00 +02:00
Yousef Rafat
31d358c78c rope, attention update | vae on cpu warning 2026-01-04 19:15:53 +02:00
Yousef R. Gamaleldin
4dd42ef1b7
Merge branch 'master' into seedvr2 2026-01-04 17:24:27 +02:00
Yousef R. Gamaleldin
02529c6d57
Merge branch 'master' into seedvr2 2025-12-31 20:20:32 +02:00
Yousef Rafat
49febe15c3 Merge branch 'seedvr2' of https://github.com/yousef-rafat/ComfyUI into seedvr2 2025-12-30 18:45:13 +02:00
Yousef Rafat
84fa155071 fixed manual vae loading 2025-12-30 18:44:57 +02:00
Jedrzej Kosinski
4691717340 Merge branch 'master' into seedvr2 2025-12-29 19:00:13 -08:00
Yousef Rafat
fadc7839cc ruff 2025-12-26 23:14:33 +02:00
Yousef Rafat
3039c7ba14 tile edge case handled by padding vid 2025-12-26 23:12:45 +02:00
Yousef Rafat
9b573da39b added other types of attention + compatibility
with images
2025-12-26 21:16:36 +02:00
Yousef Rafat
4d7012ecda . 2025-12-26 02:23:51 +02:00
Yousef Rafat
21bc67d7db final changes 2025-12-26 02:08:59 +02:00
Yousef Rafat
7b2e5ef0af outputs/speed/memory match custom node 2025-12-24 22:15:27 +02:00
Yousef Rafat
1afc2ed8e6 fixed the speed issue 2025-12-24 02:23:57 +02:00
Yousef Rafat
d41b1111eb removed print statement 2025-12-23 12:36:10 +02:00
Yousef Rafat
5b0c80a093 ruff 2025-12-23 12:35:00 +02:00
Yousef Rafat
e30298dda2 .. 2025-12-22 21:49:48 +02:00
Yousef Rafat
98b6bfcb71 revert file perm. 2025-12-22 21:46:40 +02:00
Yousef Rafat
fc5fabb629 . 2025-12-22 21:16:21 +02:00
Yousef Rafat
5db5da790f remove cfg cutoff node 2025-12-22 21:15:12 +02:00
Yousef Rafat
a4e9d071e8 video works 2025-12-22 18:12:46 +02:00
Yousef Rafat
4fe772fae9 improvements 2025-12-20 23:20:45 +02:00
Yousef Rafat
0d2044a778 ... 2025-12-19 20:28:09 +02:00
Yousef Rafat
7e62f8cc9f added var length attention and fixed the vae issue 2025-12-19 20:23:39 +02:00
Yousef Rafat
74621b9d86 . 2025-12-18 14:52:10 +02:00
Yousef Rafat
db74a27870 fix vae issue 2025-12-18 14:13:41 +02:00
Yousef Rafat
acb9a11c6f Merge branch 'seedvr2' of https://github.com/yousef-rafat/ComfyUI into seedvr2 2025-12-18 00:37:32 +02:00
Yousef Rafat
d9f71da998 works 2025-12-18 00:32:14 +02:00
Yousef R. Gamaleldin
183b377588
Merge branch 'master' into seedvr2 2025-12-17 00:39:06 +02:00
Yousef Rafat
ebd945ce3d vae fix 2025-12-17 00:09:38 +02:00
Yousef Rafat
58e7cea796 lora, 7b model, cfg 2025-12-13 19:48:57 +02:00
Yousef Rafat
768c9cedf8 .. 2025-12-12 20:51:40 +02:00
Yousef Rafat
d629c8f910 testing 2025-12-12 00:46:23 +02:00
Yousef Rafat
413ee3f687 . 2025-12-10 22:58:53 +02:00
Yousef Rafat
d12702ee0b fixed some issues 2025-12-09 23:54:56 +02:00
Yousef Rafat
f030b3afc8 mostly fixing mistakes 2025-12-09 00:16:17 +02:00
Yousef Rafat
44a5bf353a testing the model 2025-12-07 23:43:49 +02:00
Yousef Rafat
4b9332cc21 continue building nodes / testing vae 2025-12-07 21:41:14 +02:00
Yousef Rafat
041dbd6a8a add nodes 2025-12-07 01:00:08 +02:00
Yousef Rafat
08d93555d0 init 2025-12-06 23:18:10 +02:00
20 changed files with 4067 additions and 44 deletions

View File

@ -747,6 +747,10 @@ class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
class SeedVR2(LatentFormat):
latent_channels = 16
latent_dimensions = 16
class ChromaRadiance(LatentFormat):
latent_channels = 3

View File

@ -19,9 +19,15 @@ if model_management.xformers_enabled():
import xformers.ops
SAGE_ATTENTION_IS_AVAILABLE = False
SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTENTION_IS_AVAILABLE = True
try:
from sageattention import sageattn_varlen
SAGE_ATTENTION_VAR_LENGTH_AVAILABLE = True
except:
pass
except ImportError as e:
if model_management.sage_attention_enabled():
if e.name == "sageattention":
@ -39,7 +45,7 @@ except ImportError:
FLASH_ATTENTION_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
from flash_attn import flash_attn_func, flash_attn_varlen_func
FLASH_ATTENTION_IS_AVAILABLE = True
except ImportError:
if model_management.flash_attention_enabled():
@ -87,7 +93,13 @@ def default(val, d):
return val
return d
def var_attn_arg(kwargs):
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
max_seqlen_q = kwargs.get("max_seqlen_q", None)
max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
@ -412,13 +424,14 @@ except:
@wrap_attn
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
var_length = kwargs.get("var_length", False)
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
disabled_xformers = False
if BROKEN_XFORMERS:
if b * heads > 65535:
if b * heads > 65535 and not var_length:
disabled_xformers = True
if not disabled_xformers:
@ -426,9 +439,27 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, var_length=var_length, **kwargs)
if skip_reshape:
if var_length:
if not skip_reshape:
total_tokens, hidden_dim = q.shape
dim_head = hidden_dim // heads
q = q.view(1, total_tokens, heads, dim_head)
k = k.view(1, total_tokens, heads, dim_head)
v = v.view(1, total_tokens, heads, dim_head)
else:
if q.ndim == 3:
q = q.unsqueeze(0)
if k.ndim == 3:
k = k.unsqueeze(0)
if v.ndim == 3:
v = v.unsqueeze(0)
dim_head = q.shape[-1]
target_output_shape = (q.shape[1], -1)
b = 1
elif skip_reshape:
# b h k d -> b k h d
q, k, v = map(
lambda t: t.permute(0, 2, 1, 3),
@ -442,7 +473,11 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
(q, k, v),
)
if mask is not None:
if var_length:
cu_seqlens_q, _, _, _ = var_attn_arg(kwargs)
seq_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
mask = xformers.ops.BlockDiagonalMask.from_seqlens(seq_lens_q=seq_lens, seq_lens_k=seq_lens)
elif mask is not None:
# add a singleton batch dimension
if mask.ndim == 2:
mask = mask.unsqueeze(0)
@ -464,6 +499,8 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
if var_length:
return out.reshape(*target_output_shape)
if skip_output_reshape:
out = out.permute(0, 2, 1, 3)
else:
@ -481,7 +518,28 @@ else:
@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
var_length = kwargs.get("var_length", False)
if var_length:
cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
if not skip_reshape:
# assumes 2D q, k, v: [total_tokens, embed_dim]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
q = q.view(total_tokens, heads, head_dim)
k = k.view(k.shape[0], heads, head_dim)
v = v.view(v.shape[0], heads, head_dim)
b = q.size(0)
dim_head = q.shape[-1]
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
mask = None
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
elif skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
@ -499,8 +557,10 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
if SDP_BATCH_LIMIT >= b:
if SDP_BATCH_LIMIT >= b or var_length:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if var_length:
return out.contiguous().transpose(1, 2).values()
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -524,8 +584,19 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
var_length = kwargs.get("var_length", False)
exception_fallback = False
if skip_reshape:
if var_length:
if not skip_reshape:
total_tokens, hidden_dim = q.shape
dim_head = hidden_dim // heads
q, k, v = [t.view(total_tokens, heads, dim_head) for t in (q, k, v)]
b, _, dim_head = q.shape
# skips batched code
mask = None
tensor_layout = "VAR"
target_output_shape = (q.shape[0], -1)
elif skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
else:
@ -546,7 +617,14 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
mask = mask.unsqueeze(1)
try:
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
if var_length and not SAGE_ATTENTION_VAR_LENGTH_AVAILABLE:
raise ValueError("Sage Attention two is required to run variable length attention.")
elif var_length:
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
sm_scale = 1.0 / (q.shape[-1] ** 0.5)
out = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, is_causal=False, sm_scale=sm_scale)
else:
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
exception_fallback = True
@ -556,7 +634,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.transpose(1, 2),
(q, k, v),
)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, var_length=var_length, **kwargs)
if tensor_layout == "HND":
if not skip_output_reshape:
@ -567,6 +645,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
if skip_output_reshape:
out = out.transpose(1, 2)
else:
if var_length:
return out.view(*target_output_shape)
out = out.reshape(b, -1, heads * dim_head)
return out
@ -678,6 +758,15 @@ except AttributeError as error:
@wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
var_length = kwargs.get("var_length", False)
if var_length:
cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
return flash_attn_varlen_func(
q=q, k=k, v=v,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
dropout_p=0.0, softmax_scale=None, causal=False
)
if skip_reshape:
b, _, _, dim_head = q.shape
else:

View File

@ -13,13 +13,14 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
def torch_cat_if_needed(xl, dim):
if len(xl) > 1:
return torch.cat(xl, dim)
else:
return xl[0]
def get_timestep_embedding(timesteps, embedding_dim):
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos = False, downscale_freq_shift = 1):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
@ -30,11 +31,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = math.log(10000) / (half_dim - downscale_freq_shift)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb
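Below is a standalone sketch of the sinusoidal embedding with the two new parameters, assumed equivalent to the function above for even embedding_dim (the odd-dimension zero-padding branch is omitted); it is a reference illustration, not the repo function:

import math
import torch

def timestep_embedding_sketch(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
    # freq_i = exp(-ln(10000) * i / (half_dim - downscale_freq_shift)), i = 0..half_dim-1
    half_dim = embedding_dim // 2
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -(math.log(10000) / (half_dim - downscale_freq_shift)))
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if flip_sin_to_cos:  # reorder halves to [cos | sin]
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    return emb

print(timestep_embedding_sketch(torch.tensor([0, 500, 999]), 8).shape)  # torch.Size([3, 8])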

comfy/ldm/seedvr/model.py (new file, 1431 lines)

File diff suppressed because it is too large

comfy/ldm/seedvr/vae.py (new file, 1936 lines)

File diff suppressed because it is too large

View File

@ -47,6 +47,8 @@ import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.seedvr.model
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
@ -815,6 +817,16 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class SeedVR2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
condition = kwargs.get("condition", None)
if condition is not None:
out["condition"] = comfy.conds.CONDRegular(condition)
return out
class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)

View File

@ -447,6 +447,28 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
dit_config["norm_eps"] = 1e-5
dit_config["qk_rope"] = True
dit_config["mlp_type"] = "normal"
return dit_config
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 2560
dit_config["heads"] = 20
dit_config["num_layers"] = 32
dit_config["norm_eps"] = 1.0e-05
dit_config["qk_rope"] = None
dit_config["mlp_type"] = "swiglu"
dit_config["vid_out_norm"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"

View File

@ -718,6 +718,7 @@ class ModelPatcher:
continue
cast_weight = self.force_cast_weights
m.comfy_force_cast_weights = self.force_cast_weights
if lowvram_weight:
if hasattr(m, "comfy_cast_weights"):
m.weight_function = []
@ -790,11 +791,12 @@ class ModelPatcher:
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
if lowvram_counter > 0:
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)

View File

@ -654,29 +654,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
run_every_op()
input_shape = input.shape
tensor_3d = input.ndim == 3
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
reshaped_3d = False
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor)):
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0):
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
if tensor_3d:
input = input.reshape(-1, input_shape[2])
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
if input.ndim != 2:
# Fall back to comfy_cast_weights for non-2D tensors
return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs)
# Fall back to non-quantized for non-2D tensors
if input_reshaped.ndim == 2:
reshaped_3d = input.ndim == 3
# dtype is now implicit in the layout class
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
# dtype is now implicit in the layout class
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
output = self._forward(input, self.weight, self.bias)
output = self.forward_comfy_cast_weights(input)
# Reshape output back to 3D if input was 3D
if tensor_3d:
if reshaped_3d:
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
return output

View File

@ -19,6 +19,7 @@ try:
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
if cuda_version < (13,):
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton")
for k, v in ck.list_backends().items():

comfy/samplers.py: Executable file → Normal file (0 lines changed)

View File

@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.seedvr.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
@ -218,7 +219,7 @@ class CLIP:
if unprojected:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
@ -266,7 +267,7 @@ class CLIP:
if return_pooled == "unprojected":
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
@ -299,8 +300,11 @@ class CLIP:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def load_model(self):
model_management.load_model_gpu(self.patcher)
def load_model(self, tokens={}):
memory_used = 0
if hasattr(self.cond_stage_model, "memory_estimation_function"):
memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
return self.patcher
def get_key_patches(self):
@ -309,7 +313,10 @@ class CLIP:
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
if (metadata is not None and metadata["keep_diffusers_format"] == "true"):
pass
else:
sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
@ -376,6 +383,17 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.process_input = lambda image: image
self.crop_input = False
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
@ -476,13 +494,14 @@ class VAE:
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.upscale_index_formula = (8, 32, 32)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]

View File

@ -845,7 +845,7 @@ class LTXAV(LTXV):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 0.055 # TODO
self.memory_usage_factor = 0.061 # TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXAV(self, device=device)
@ -1303,6 +1303,25 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"
}
latent_format = comfy.latent_formats.SeedVR2
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
supported_inference_dtypes = [torch.bfloat16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
def get_model(self, state_dict, prefix = "", device=None):
out = model_base.SeedVR2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
class ChromaRadiance(Chroma):
unet_config = {
"image_model": "chroma_radiance",
@ -1551,6 +1570,6 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, SeedVR2]
models += [SVD_img2vid]

View File

@ -98,10 +98,13 @@ class LTXAVTEModel(torch.nn.Module):
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
@ -118,6 +121,14 @@ class LTXAVTEModel(torch.nn.Module):
return self.load_state_dict(sdo, strict=False)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0
if comfy.model_management.should_use_bf16(device):
constant /= 2.0
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
class LTXAVTEModel_(LTXAVTEModel):

View File

@ -185,6 +185,10 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
io.Combo.Input(
"ckpt_name",
options=folder_paths.get_filename_list("checkpoints"),
),
io.Combo.Input(
"device",
options=["default", "cpu"],
)
],
outputs=[io.Clip.Output()],
@ -197,7 +201,11 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return io.NodeOutput(clip)

View File

@ -0,0 +1,465 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import torch
import math
from einops import rearrange
import gc
import comfy.model_management
from comfy.utils import ProgressBar
import torch.nn.functional as F
from torchvision.transforms import functional as TVF
from torchvision.transforms import Lambda, Normalize
from torchvision.transforms.functional import InterpolationMode
@torch.inference_mode()
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, encode=True):
gc.collect()
torch.cuda.empty_cache()
x = x.to(next(vae_model.parameters()).dtype)
if x.ndim != 5:
x = x.unsqueeze(2)
b, c, d, h, w = x.shape
sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
if encode:
ti_h, ti_w = tile_size
ov_h, ov_w = tile_overlap
target_d = (d + sf_t - 1) // sf_t
target_h = (h + sf_s - 1) // sf_s
target_w = (w + sf_s - 1) // sf_s
else:
ti_h = max(1, tile_size[0] // sf_s)
ti_w = max(1, tile_size[1] // sf_s)
ov_h = max(0, tile_overlap[0] // sf_s)
ov_w = max(0, tile_overlap[1] // sf_s)
target_d = d * sf_t
target_h = h * sf_s
target_w = w * sf_s
stride_h = max(1, ti_h - ov_h)
stride_w = max(1, ti_w - ov_w)
storage_device = vae_model.device
result = None
count = None
def run_temporal_chunks(spatial_tile):
chunk_results = []
t_dim_size = spatial_tile.shape[2]
if encode:
input_chunk = temporal_size
else:
input_chunk = max(1, temporal_size // sf_t)
for i in range(0, t_dim_size, input_chunk):
t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
current_valid_len = t_chunk.shape[2]
pad_amount = 0
if current_valid_len < input_chunk:
pad_amount = input_chunk - current_valid_len
last_frame = t_chunk[:, :, -1:, :, :]
padding = last_frame.repeat(1, 1, pad_amount, 1, 1)
t_chunk = torch.cat([t_chunk, padding], dim=2)
t_chunk = t_chunk.contiguous()
if encode:
out = vae_model.encode(t_chunk)[0]
else:
out = vae_model.decode_(t_chunk)
if isinstance(out, (tuple, list)):
out = out[0]
if out.ndim == 4:
out = out.unsqueeze(2)
if pad_amount > 0:
if encode:
expected_valid_out = (current_valid_len + sf_t - 1) // sf_t
out = out[:, :, :expected_valid_out, :, :]
else:
expected_valid_out = current_valid_len * sf_t
out = out[:, :, :expected_valid_out, :, :]
chunk_results.append(out.to(storage_device))
return torch.cat(chunk_results, dim=2)
ramp_cache = {}
def get_ramp(steps):
if steps not in ramp_cache:
t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32)
ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi)
return ramp_cache[steps]
total_tiles = len(range(0, h, stride_h)) * len(range(0, w, stride_w))
bar = ProgressBar(total_tiles)
for y_idx in range(0, h, stride_h):
y_end = min(y_idx + ti_h, h)
for x_idx in range(0, w, stride_w):
x_end = min(x_idx + ti_w, w)
tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]
# Run VAE
tile_out = run_temporal_chunks(tile_x)
if result is None:
b_out, c_out = tile_out.shape[0], tile_out.shape[1]
result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32)
count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32)
if encode:
ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
else:
ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3]
xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4]
cur_ov_h = max(0, min(ov_h, tile_out.shape[3] // 2))
cur_ov_w = max(0, min(ov_w, tile_out.shape[4] // 2))
w_h = torch.ones((tile_out.shape[3],), device=storage_device)
w_w = torch.ones((tile_out.shape[4],), device=storage_device)
if cur_ov_h > 0:
r = get_ramp(cur_ov_h)
if y_idx > 0:
w_h[:cur_ov_h] = r
if y_end < h:
w_h[-cur_ov_h:] = 1.0 - r
if cur_ov_w > 0:
r = get_ramp(cur_ov_w)
if x_idx > 0:
w_w[:cur_ov_w] = r
if x_end < w:
w_w[-cur_ov_w:] = 1.0 - r
final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
valid_d = min(tile_out.shape[2], result.shape[2])
tile_out = tile_out[:, :, :valid_d, :, :]
tile_out.mul_(final_weight)
result[:, :, :valid_d, ys:ye, xs:xe] += tile_out
count[:, :, :, ys:ye, xs:xe] += final_weight
del tile_out, final_weight, tile_x, w_h, w_w
bar.update(1)
result.div_(count.clamp(min=1e-6))
if result.device != x.device:
result = result.to(x.device).to(x.dtype)
if x.shape[2] == 1 and sf_t == 1:
result = result.squeeze(2)
return result
def pad_video_temporal(videos: torch.Tensor, count: int = 0, temporal_dim: int = 1, prepend: bool = False):
t = videos.size(temporal_dim)
if count == 0 and not prepend:
if t % 4 == 1:
return videos
count = ((t - 1) // 4 + 1) * 4 + 1 - t
if count <= 0:
return videos
def select(start, end):
return videos[start:end] if temporal_dim == 0 else videos[:, start:end]
if count >= t:
repeat_count = count - t + 1
last = select(-1, None)
if temporal_dim == 0:
repeated = last.repeat(repeat_count, 1, 1, 1)
reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:0]
else:
repeated = last.expand(-1, repeat_count, -1, -1).contiguous()
reversed_frames = select(1, None).flip(temporal_dim) if t > 1 else last[:, :0]
return torch.cat([repeated, reversed_frames, videos] if prepend else
[videos, reversed_frames, repeated], dim=temporal_dim)
if prepend:
reversed_frames = select(1, count+1).flip(temporal_dim)
else:
reversed_frames = select(-count-1, -1).flip(temporal_dim)
return torch.cat([reversed_frames, videos] if prepend else
[videos, reversed_frames], dim=temporal_dim)
def clear_vae_memory(vae_model):
for module in vae_model.modules():
if hasattr(module, "memory"):
module.memory = None
gc.collect()
torch.cuda.empty_cache()
def expand_dims(tensor, ndim):
shape = tensor.shape + (1,) * (ndim - tensor.ndim)
return tensor.reshape(shape)
def get_conditions(latent, latent_blur):
t, h, w, c = latent.shape
cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
cond[:, ..., :-1] = latent_blur[:]
cond[:, ..., -1:] = 1.0
return cond
def timestep_transform(timesteps, latents_shapes):
vt = 4
vs = 8
frames = (latents_shapes[:, 0] - 1) * vt + 1
heights = latents_shapes[:, 1] * vs
widths = latents_shapes[:, 2] * vs
# Compute shift factor.
def get_lin_function(x1, y1, x2, y2):
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
shift = torch.where(
frames > 1,
vid_shift_fn(heights * widths * frames),
img_shift_fn(heights * widths),
).to(timesteps.device)
# Shift timesteps.
T = 1000.0
timesteps = timesteps / T
timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
timesteps = timesteps * T
return timesteps
def inter(x_0, x_T, t):
t = expand_dims(t, x_0.ndim)
T = 1000.0
B = lambda t: t / T
A = lambda t: 1 - (t / T)
return A(t) * x_0 + B(t) * x_T
def area_resize(image, max_area):
height, width = image.shape[-2:]
scale = math.sqrt(max_area / (height * width))
resized_height, resized_width = round(height * scale), round(width * scale)
return TVF.resize(
image,
size=(resized_height, resized_width),
interpolation=InterpolationMode.BICUBIC,
)
def div_pad(image, factor):
height_factor, width_factor = factor
height, width = image.shape[-2:]
pad_height = (height_factor - (height % height_factor)) % height_factor
pad_width = (width_factor - (width % width_factor)) % width_factor
if pad_height == 0 and pad_width == 0:
return image
if isinstance(image, torch.Tensor):
padding = (0, pad_width, 0, pad_height)
image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
return image
def cut_videos(videos):
t = videos.size(1)
if t == 1:
return videos
if t <= 4 :
padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1)
padding = torch.cat(padding, dim=1)
videos = torch.cat([videos, padding], dim=1)
return videos
if (t - 1) % (4) == 0:
return videos
else:
padding = [videos[:, -1].unsqueeze(1)] * (
4 - ((t - 1) % (4))
)
padding = torch.cat(padding, dim=1)
videos = torch.cat([videos, padding], dim=1)
assert (videos.size(1) - 1) % (4) == 0
return videos
def side_resize(image, size):
antialias = not (isinstance(image, torch.Tensor) and image.device.type == 'mps')
resized = TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=antialias)
return resized
class SeedVR2InputProcessing(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id = "SeedVR2InputProcessing",
category="image/video",
inputs = [
io.Image.Input("images"),
io.Vae.Input("vae"),
io.Int.Input("resolution", default = 1280, min = 120), # just non-zero value
io.Int.Input("spatial_tile_size", default = 512, min = 1),
io.Int.Input("spatial_overlap", default = 64, min = 1),
io.Int.Input("temporal_tile_size", default=5, min=1, max=16384, step=4),
io.Boolean.Input("enable_tiling", default=False),
],
outputs = [
io.Latent.Output("vae_conditioning")
]
)
@classmethod
def execute(cls, images, vae, resolution, spatial_tile_size, temporal_tile_size, spatial_overlap, enable_tiling):
comfy.model_management.load_models_gpu([vae.patcher])
vae_model = vae.first_stage_model
scale = 0.9152
shift = 0
if images.dim() != 5: # add the t dim
images = images.unsqueeze(0)
images = images.permute(0, 1, 4, 2, 3)
b, t, c, h, w = images.shape
images = images.reshape(b * t, c, h, w)
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
normalize = Normalize(0.5, 0.5)
images = side_resize(images, resolution)
images = clip(images)
o_h, o_w = images.shape[-2:]
images = div_pad(images, (16, 16))
images = normalize(images)
_, _, new_h, new_w = images.shape
images = images.reshape(b, t, c, new_h, new_w)
images = cut_videos(images)
images = rearrange(images, "b t c h w -> b c t h w")
# in case the user passes a non-compatible number for tiling
def make_divisible(val, divisor):
return max(divisor, round(val / divisor) * divisor)
spatial_tile_size = make_divisible(spatial_tile_size, 32)
spatial_overlap = make_divisible(spatial_overlap, 32)
if spatial_overlap >= spatial_tile_size:
spatial_overlap = max(0, spatial_tile_size - 8)
args = {"tile_size": (spatial_tile_size, spatial_tile_size), "tile_overlap": (spatial_overlap, spatial_overlap),
"temporal_size":temporal_tile_size}
if enable_tiling:
latent = tiled_vae(images, vae_model, encode=True, **args)
else:
latent = vae_model.encode(images, orig_dims = [o_h, o_w])[0]
clear_vae_memory(vae_model)
#images = images.to(offload_device)
#vae_model = vae_model.to(offload_device)
vae_model.img_dims = [o_h, o_w]
args["enable_tiling"] = enable_tiling
vae_model.tiled_args = args
vae_model.original_image_video = images
latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
latent = rearrange(latent, "b c ... -> b ... c")
latent = (latent - shift) * scale
return io.NodeOutput({"samples": latent})
class SeedVR2Conditioning(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SeedVR2Conditioning",
category="image/video",
inputs=[
io.Latent.Input("vae_conditioning"),
io.Model.Input("model"),
],
outputs=[io.Conditioning.Output(display_name = "positive"),
io.Conditioning.Output(display_name = "negative"),
io.Latent.Output(display_name = "latent")],
)
@classmethod
def execute(cls, vae_conditioning, model) -> io.NodeOutput:
vae_conditioning = vae_conditioning["samples"]
device = vae_conditioning.device
model = model.model.diffusion_model
pos_cond = model.positive_conditioning
neg_cond = model.negative_conditioning
noises = torch.randn_like(vae_conditioning).to(device)
aug_noises = torch.randn_like(vae_conditioning).to(device)
aug_noises = noises * 0.1 + aug_noises * 0.05
cond_noise_scale = 0.0
t = (
torch.tensor([1000.0])
* cond_noise_scale
).to(device)
shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None] # avoid batch dim
t = timestep_transform(t, shape)
cond = inter(vae_conditioning, aug_noises, t)
condition = torch.stack([get_conditions(noise, c) for noise, c in zip(noises, cond)])
condition = condition.movedim(-1, 1)
noises = noises.movedim(-1, 1)
pos_shape = pos_cond.shape[0]
neg_shape = neg_cond.shape[0]
diff = abs(pos_shape - neg_shape)
if pos_shape > neg_shape:
neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
else:
pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
noises = rearrange(noises, "b c t h w -> b (c t) h w")
condition = rearrange(condition, "b c t h w -> b (c t) h w")
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
return io.NodeOutput(positive, negative, {"samples": noises})
class SeedVRExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SeedVR2Conditioning,
SeedVR2InputProcessing
]
async def comfy_entrypoint() -> SeedVRExtension:
return SeedVRExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.8.0"
__version__ = "0.8.2"

View File

@ -2358,6 +2358,7 @@ async def init_builtin_extra_nodes():
"nodes_camera_trajectory.py",
"nodes_edit_model.py",
"nodes_tcfg.py",
"nodes_seedvr.py",
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_chroma_radiance.py",

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.8.0"
version = "0.8.2"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.35.9
comfyui-workflow-templates==0.7.67
comfyui-workflow-templates==0.7.69
comfyui-embedded-docs==0.3.1
torch
torchsde
@ -21,7 +21,7 @@ psutil
alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.3
comfy-kitchen>=0.2.5
#non essential dependencies:
kornia>=0.7.1