diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py
index 46ef21c95..16998af94 100644
--- a/comfy/audio_encoders/audio_encoders.py
+++ b/comfy/audio_encoders/audio_encoders.py
@@ -25,11 +25,11 @@ class AudioEncoderModel():
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
- self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+ self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
def load_sd(self, sd):
- return self.model.load_state_dict(sd, strict=False)
+ return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 1716c3de7..63daca861 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
+ DynamicVRAM = "dynamic_vram"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
@@ -257,3 +258,6 @@ elif args.fast == []:
# '--fast' is provided with a list of performance features, use that list
else:
args.fast = set(args.fast)
+
+def enables_dynamic_vram():
+ return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only
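
After parsing, `args.fast` is normalized to a `set` of `PerformanceFeature` members, so the new gate is a cheap membership test combined with the two VRAM-mode exclusions. A minimal sketch of how a caller might branch on it (illustrative only):

```python
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram

# with `--fast dynamic_vram` and neither --highvram nor --gpu-only,
# enables_dynamic_vram() returns True
if enables_dynamic_vram():
    print("dynamic VRAM path enabled")
elif PerformanceFeature.DynamicVRAM in args.fast:
    print("dynamic_vram requested but overridden by --highvram/--gpu-only")
```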
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index b28bf636c..1691fca81 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -47,10 +47,10 @@ class ClipVisionModel():
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
- self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+ self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
- return self.model.load_state_dict(sd, strict=False)
+ return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 0b5e30f52..9e1e704e0 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -203,7 +203,7 @@ class ControlNet(ControlBase):
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
- self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
+ self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
self.compression_ratio = compression_ratio
self.global_average_pooling = global_average_pooling
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 0949dee44..c0c51d51a 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1,11 +1,12 @@
import math
+import time
from functools import partial
from scipy import integrate
import torch
from torch import nn
import torchsde
-from tqdm.auto import trange, tqdm
+from tqdm.auto import trange as trange_, tqdm
from . import utils
from . import deis
@@ -13,6 +14,36 @@ from . import sa_solver
import comfy.model_patcher
import comfy.model_sampling
+import comfy.memory_management
+
+
+def trange(*args, **kwargs):
+ if comfy.memory_management.aimdo_allocator is None:
+ return trange_(*args, **kwargs)
+
+ pbar = trange_(*args, **kwargs, smoothing=1.0)
+ pbar._i = 0
+ pbar.set_postfix_str(" Model Initializing ... ")
+
+ _update = pbar.update
+
+ def warmup_update(n=1):
+ pbar._i += 1
+ if pbar._i == 1:
+ pbar.i1_time = time.time()
+ pbar.set_postfix_str(" Model Initialization complete! ")
+ elif pbar._i == 2:
+            #bring forward the effective start time based on the diff between the first and second iterations
+            #to try to remove load overhead from the final step rate estimate.
+ pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
+ pbar.set_postfix_str("")
+
+ _update(n)
+
+ pbar.update = warmup_update
+ return pbar
+
+
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
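
The wrapper above leans on tqdm's internal `start_t` attribute: once the second step completes, the start time is rewound by one steady-state step so the first, load-dominated iteration no longer skews the it/s estimate. A standalone sketch of the same correction (the sleep timings are hypothetical):

```python
import time
from tqdm.auto import trange

pbar = trange(10)
for i in pbar:
    time.sleep(2.0 if i == 0 else 0.1)  # slow first step simulates model load
    if i == 0:
        i1_time = time.time()
    elif i == 1:
        # pretend the run started one steady-state step before the first
        # iteration finished, exactly as warmup_update does above
        pbar.start_t = i1_time - (time.time() - i1_time)
```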
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 4b3a3798c..f59999af6 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -755,6 +755,10 @@ class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
+class ACEAudio15(LatentFormat):
+ latent_channels = 64
+ latent_dimensions = 1
+
class ChromaRadiance(LatentFormat):
latent_channels = 3
spacial_downscale_ratio = 1
diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py
new file mode 100644
index 000000000..d90549658
--- /dev/null
+++ b/comfy/ldm/ace/ace_step15.py
@@ -0,0 +1,1093 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import itertools
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.model_management
+from comfy.ldm.flux.layers import timestep_embedding
+
+def get_layer_class(operations, layer_name):
+ if operations is not None and hasattr(operations, layer_name):
+ return getattr(operations, layer_name)
+ return getattr(nn, layer_name)
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=32768, base=1000000.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.dim = dim
+ self.base = base
+ self.max_position_embeddings = max_position_embeddings
+
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self._set_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.get_default_dtype() if dtype is None else dtype)
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
+ freqs = torch.outer(t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, seq_len=None):
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len, x.device, x.dtype)
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device),
+ self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device),
+ )
+
+def rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+ cos = cos.unsqueeze(0).unsqueeze(0)
+ sin = sin.unsqueeze(0).unsqueeze(0)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
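A quick sanity sketch of the helpers above: `apply_rotary_pos_emb` applies a block rotation per position, so vector norms are preserved. Toy shapes, with cos/sin built the same way as `RotaryEmbedding._set_cos_sin_cache`:

```python
import torch

dim, seq = 8, 4
inv_freq = 1.0 / (1000000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
freqs = torch.outer(torch.arange(seq, dtype=torch.float32), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
q = torch.randn(1, 1, seq, dim)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 1, seq, dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin())
# a rotation per position leaves norms unchanged
assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
```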
+class MLP(nn.Module):
+ def __init__(self, hidden_size, intermediate_size, dtype=None, device=None, operations=None):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.gate_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
+ self.up_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
+ self.down_proj = Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device)
+ self.act_fn = nn.SiLU()
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, in_channels: int, time_embed_dim: int, scale: float = 1000, dtype=None, device=None, operations=None):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, dtype=dtype, device=device)
+ self.act1 = nn.SiLU()
+ self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, dtype=dtype, device=device)
+ self.in_channels = in_channels
+ self.act2 = nn.SiLU()
+ self.time_proj = Linear(time_embed_dim, time_embed_dim * 6, dtype=dtype, device=device)
+ self.scale = scale
+
+ def forward(self, t, dtype=None):
+ t_freq = timestep_embedding(t, self.in_channels, time_factor=self.scale)
+ temb = self.linear_1(t_freq.to(dtype=dtype))
+ temb = self.act1(temb)
+ temb = self.linear_2(temb)
+ timestep_proj = self.time_proj(self.act2(temb)).view(-1, 6, temb.shape[-1])
+ return temb, timestep_proj
+
+class AceStepAttention(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ rms_norm_eps=1e-6,
+ is_cross_attention=False,
+ sliding_window=None,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.num_heads = num_heads
+ self.num_kv_heads = num_kv_heads
+ self.head_dim = head_dim
+ self.is_cross_attention = is_cross_attention
+ self.sliding_window = sliding_window
+
+ Linear = get_layer_class(operations, "Linear")
+
+ self.q_proj = Linear(hidden_size, num_heads * head_dim, bias=False, dtype=dtype, device=device)
+ self.k_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device)
+ self.v_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device)
+ self.o_proj = Linear(num_heads * head_dim, hidden_size, bias=False, dtype=dtype, device=device)
+
+ self.q_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.k_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device)
+
+ def forward(
+ self,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ position_embeddings=None,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+ query_states = self.q_norm(query_states)
+ query_states = query_states.transpose(1, 2)
+
+ if self.is_cross_attention and encoder_hidden_states is not None:
+ bsz_enc, kv_len, _ = encoder_hidden_states.size()
+ key_states = self.k_proj(encoder_hidden_states)
+ value_states = self.v_proj(encoder_hidden_states)
+
+ key_states = key_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim)
+ key_states = self.k_norm(key_states)
+ value_states = value_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim)
+
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ else:
+ kv_len = q_len
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
+ key_states = self.k_norm(key_states)
+ value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
+
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ if position_embeddings is not None:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ n_rep = self.num_heads // self.num_kv_heads
+ if n_rep > 1:
+ key_states = key_states.repeat_interleave(n_rep, dim=1)
+ value_states = value_states.repeat_interleave(n_rep, dim=1)
+
+ attn_bias = None
+ if self.sliding_window is not None and not self.is_cross_attention:
+ indices = torch.arange(q_len, device=query_states.device)
+ diff = indices.unsqueeze(1) - indices.unsqueeze(0)
+ in_window = torch.abs(diff) <= self.sliding_window
+
+ window_bias = torch.zeros((q_len, kv_len), device=query_states.device, dtype=query_states.dtype)
+ min_value = torch.finfo(query_states.dtype).min
+ window_bias.masked_fill_(~in_window, min_value)
+
+ window_bias = window_bias.unsqueeze(0).unsqueeze(0)
+
+ if attn_bias is not None:
+ if attn_bias.dtype == torch.bool:
+ base_bias = torch.zeros_like(window_bias)
+ base_bias.masked_fill_(~attn_bias, min_value)
+ attn_bias = base_bias + window_bias
+ else:
+ attn_bias = attn_bias + window_bias
+ else:
+ attn_bias = window_bias
+
+ attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output
+
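Note that the sliding-window bias built above is symmetric (|i - j| <= window), i.e. bidirectional rather than causal, which suits a diffusion transformer. A toy rendering of the band mask:

```python
import torch

q_len, window = 6, 2
idx = torch.arange(q_len)
in_window = (idx.unsqueeze(1) - idx.unsqueeze(0)).abs() <= window
bias = torch.zeros(q_len, q_len).masked_fill(~in_window, torch.finfo(torch.float32).min)
# bias[i, j] == 0 where |i - j| <= 2; a large negative value elsewhere,
# which zeroes those attention weights after softmax
```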
+class AceStepDiTLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ intermediate_size,
+ rms_norm_eps=1e-6,
+ layer_type="full_attention",
+ sliding_window=128,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+
+ self_attn_window = sliding_window if layer_type == "sliding_attention" else None
+
+ self.self_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.self_attn = AceStepAttention(
+ hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
+ is_cross_attention=False, sliding_window=self_attn_window,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ self.cross_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.cross_attn = AceStepAttention(
+ hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
+ is_cross_attention=True, dtype=dtype, device=device, operations=operations
+ )
+
+ self.mlp_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations)
+
+ self.scale_shift_table = nn.Parameter(torch.empty(1, 6, hidden_size, dtype=dtype, device=device))
+
+ def forward(
+ self,
+ hidden_states,
+ temb,
+ encoder_hidden_states,
+ position_embeddings,
+ attention_mask=None,
+ encoder_attention_mask=None
+ ):
+ modulation = comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = modulation.chunk(6, dim=1)
+
+ norm_hidden = self.self_attn_norm(hidden_states)
+ norm_hidden = norm_hidden * (1 + scale_msa) + shift_msa
+
+ attn_out = self.self_attn(
+ norm_hidden,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask
+ )
+ hidden_states = hidden_states + attn_out * gate_msa
+
+ norm_hidden = self.cross_attn_norm(hidden_states)
+ attn_out = self.cross_attn(
+ norm_hidden,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask
+ )
+ hidden_states = hidden_states + attn_out
+
+ norm_hidden = self.mlp_norm(hidden_states)
+ norm_hidden = norm_hidden * (1 + c_scale_msa) + c_shift_msa
+
+ mlp_out = self.mlp(norm_hidden)
+ hidden_states = hidden_states + mlp_out * c_gate_msa
+
+ return hidden_states
+
+class AceStepEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ intermediate_size,
+ rms_norm_eps=1e-6,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ self.self_attn = AceStepAttention(
+ hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
+ is_cross_attention=False, dtype=dtype, device=device, operations=operations
+ )
+ self.input_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.post_attention_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations)
+
+ def forward(self, hidden_states, position_embeddings, attention_mask=None):
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+class AceStepLyricEncoder(nn.Module):
+ def __init__(
+ self,
+ text_hidden_dim,
+ hidden_size,
+ num_layers,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ intermediate_size,
+ rms_norm_eps=1e-6,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.embed_tokens = Linear(text_hidden_dim, hidden_size, dtype=dtype, device=device)
+ self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+
+ self.rotary_emb = RotaryEmbedding(
+ head_dim,
+ base=1000000.0,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+
+ self.layers = nn.ModuleList([
+ AceStepEncoderLayer(
+ hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(num_layers)
+ ])
+
+ def forward(self, inputs_embeds, attention_mask=None):
+ hidden_states = self.embed_tokens(inputs_embeds)
+ seq_len = hidden_states.shape[1]
+ cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len)
+ position_embeddings = (cos, sin)
+
+ for layer in self.layers:
+ hidden_states = layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
+
+class AceStepTimbreEncoder(nn.Module):
+ def __init__(
+ self,
+ timbre_hidden_dim,
+ hidden_size,
+ num_layers,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ intermediate_size,
+ rms_norm_eps=1e-6,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.embed_tokens = Linear(timbre_hidden_dim, hidden_size, dtype=dtype, device=device)
+ self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+
+ self.rotary_emb = RotaryEmbedding(
+ head_dim,
+ base=1000000.0,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+
+ self.layers = nn.ModuleList([
+ AceStepEncoderLayer(
+ hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(num_layers)
+ ])
+ self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
+
+ def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
+ N, d = timbre_embs_packed.shape
+ device = timbre_embs_packed.device
+ B = N
+ counts = torch.bincount(refer_audio_order_mask, minlength=B)
+ max_count = counts.max().item()
+
+ sorted_indices = torch.argsort(
+ refer_audio_order_mask * N + torch.arange(N, device=device),
+ stable=True
+ )
+ sorted_batch_ids = refer_audio_order_mask[sorted_indices]
+
+ positions = torch.arange(N, device=device)
+ batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
+ positions_in_sorted = positions - batch_starts[sorted_batch_ids]
+
+ inverse_indices = torch.empty_like(sorted_indices)
+ inverse_indices[sorted_indices] = torch.arange(N, device=device)
+ positions_in_batch = positions_in_sorted[inverse_indices]
+
+ indices_2d = refer_audio_order_mask * max_count + positions_in_batch
+ one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(timbre_embs_packed.dtype)
+
+ timbre_embs_flat = one_hot.t() @ timbre_embs_packed
+ timbre_embs_unpack = timbre_embs_flat.view(B, max_count, d)
+
+ mask_flat = (one_hot.sum(dim=0) > 0).long()
+ new_mask = mask_flat.view(B, max_count)
+ return timbre_embs_unpack, new_mask
+
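A worked example of the one-hot scatter above (`self` is unused, so the method can be exercised unbound for illustration): three packed embeddings belonging to batches [0, 1, 0] are regrouped per batch and right-padded to the largest count:

```python
import torch

packed = torch.tensor([[1.], [2.], [3.]])   # (N=3, d=1)
order = torch.tensor([0, 1, 0])             # batch id per packed row
embs, mask = AceStepTimbreEncoder.unpack_timbre_embeddings(None, packed, order)
# embs -> [[[1.], [3.]], [[2.], [0.]], [[0.], [0.]]]   (B=3, max_count=2, d=1)
# mask -> [[1, 1], [1, 0], [0, 0]]
```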
+ def forward(self, refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, attention_mask=None):
+ hidden_states = self.embed_tokens(refer_audio_acoustic_hidden_states_packed)
+ if hidden_states.dim() == 2:
+ hidden_states = hidden_states.unsqueeze(0)
+
+ seq_len = hidden_states.shape[1]
+ cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len)
+
+ for layer in self.layers:
+ hidden_states = layer(
+ hidden_states,
+ position_embeddings=(cos, sin),
+ attention_mask=attention_mask
+ )
+ hidden_states = self.norm(hidden_states)
+
+ flat_states = hidden_states[:, 0, :]
+ unpacked_embs, unpacked_mask = self.unpack_timbre_embeddings(flat_states, refer_audio_order_mask)
+ return unpacked_embs, unpacked_mask
+
+
+def pack_sequences(hidden1, hidden2, mask1, mask2):
+ hidden_cat = torch.cat([hidden1, hidden2], dim=1)
+ B, L, D = hidden_cat.shape
+
+ if mask1 is not None and mask2 is not None:
+ mask_cat = torch.cat([mask1, mask2], dim=1)
+ sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
+ gather_idx = sort_idx.unsqueeze(-1).expand(B, L, D)
+ hidden_sorted = torch.gather(hidden_cat, 1, gather_idx)
+ lengths = mask_cat.sum(dim=1)
+ new_mask = (torch.arange(L, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
+ else:
+ new_mask = None
+ hidden_sorted = hidden_cat
+
+ return hidden_sorted, new_mask
+
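A worked example of `pack_sequences`: the two sequences are concatenated, then valid tokens are stably moved to the front and the mask is rebuilt as a prefix of ones:

```python
import torch

h1 = torch.tensor([[[1.], [2.]]])   # (B=1, L=2, D=1)
h2 = torch.tensor([[[3.]]])         # (B=1, L=1, D=1)
m1 = torch.tensor([[1, 0]])         # second token of h1 is padding
m2 = torch.tensor([[1]])
packed, mask = pack_sequences(h1, h2, m1, m2)
# packed -> [[[1.], [3.], [2.]]]; mask -> [[True, True, False]]
```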
+class AceStepConditionEncoder(nn.Module):
+ def __init__(
+ self,
+ text_hidden_dim,
+ timbre_hidden_dim,
+ hidden_size,
+ num_lyric_layers,
+ num_timbre_layers,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ intermediate_size,
+ rms_norm_eps=1e-6,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.text_projector = Linear(text_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device)
+
+ self.lyric_encoder = AceStepLyricEncoder(
+ text_hidden_dim=text_hidden_dim,
+ hidden_size=hidden_size,
+ num_layers=num_lyric_layers,
+ num_heads=num_heads,
+ num_kv_heads=num_kv_heads,
+ head_dim=head_dim,
+ intermediate_size=intermediate_size,
+ rms_norm_eps=rms_norm_eps,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+
+ self.timbre_encoder = AceStepTimbreEncoder(
+ timbre_hidden_dim=timbre_hidden_dim,
+ hidden_size=hidden_size,
+ num_layers=num_timbre_layers,
+ num_heads=num_heads,
+ num_kv_heads=num_kv_heads,
+ head_dim=head_dim,
+ intermediate_size=intermediate_size,
+ rms_norm_eps=rms_norm_eps,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+
+ def forward(
+ self,
+ text_hidden_states=None,
+ text_attention_mask=None,
+ lyric_hidden_states=None,
+ lyric_attention_mask=None,
+ refer_audio_acoustic_hidden_states_packed=None,
+ refer_audio_order_mask=None
+ ):
+ text_emb = self.text_projector(text_hidden_states)
+
+ lyric_emb = self.lyric_encoder(
+ inputs_embeds=lyric_hidden_states,
+ attention_mask=lyric_attention_mask
+ )
+
+ timbre_emb, timbre_mask = self.timbre_encoder(
+ refer_audio_acoustic_hidden_states_packed,
+ refer_audio_order_mask
+ )
+
+ merged_emb, merged_mask = pack_sequences(lyric_emb, timbre_emb, lyric_attention_mask, timbre_mask)
+ final_emb, final_mask = pack_sequences(merged_emb, text_emb, merged_mask, text_attention_mask)
+
+ return final_emb, final_mask
+
+# --------------------------------------------------------------------------------
+# Main Diffusion Model (DiT)
+# --------------------------------------------------------------------------------
+
+class AceStepDiTModel(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ hidden_size,
+ num_layers,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ intermediate_size,
+ patch_size,
+ audio_acoustic_hidden_dim,
+ layer_types=None,
+ sliding_window=128,
+ rms_norm_eps=1e-6,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+ self.rotary_emb = RotaryEmbedding(
+ head_dim,
+ base=1000000.0,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+
+ Conv1d = get_layer_class(operations, "Conv1d")
+ ConvTranspose1d = get_layer_class(operations, "ConvTranspose1d")
+ Linear = get_layer_class(operations, "Linear")
+
+ self.proj_in = nn.Sequential(
+ nn.Identity(),
+ Conv1d(
+ in_channels, hidden_size, kernel_size=patch_size, stride=patch_size,
+ dtype=dtype, device=device))
+
+ self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
+ self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
+ self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
+
+ if layer_types is None:
+ layer_types = ["full_attention"] * num_layers
+
+ if len(layer_types) < num_layers:
+ layer_types = list(itertools.islice(itertools.cycle(layer_types), num_layers))
+
+ self.layers = nn.ModuleList([
+ AceStepDiTLayer(
+ hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
+ layer_type=layer_types[i],
+ sliding_window=sliding_window,
+ dtype=dtype, device=device, operations=operations
+ ) for i in range(num_layers)
+ ])
+
+ self.norm_out = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.proj_out = nn.Sequential(
+ nn.Identity(),
+ ConvTranspose1d(hidden_size, audio_acoustic_hidden_dim, kernel_size=patch_size, stride=patch_size, dtype=dtype, device=device)
+ )
+
+ self.scale_shift_table = nn.Parameter(torch.empty(1, 2, hidden_size, dtype=dtype, device=device))
+
+ def forward(
+ self,
+ hidden_states,
+ timestep,
+ timestep_r,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ context_latents
+ ):
+ temb_t, proj_t = self.time_embed(timestep, dtype=hidden_states.dtype)
+ temb_r, proj_r = self.time_embed_r(timestep - timestep_r, dtype=hidden_states.dtype)
+ temb = temb_t + temb_r
+ timestep_proj = proj_t + proj_r
+
+ x = torch.cat([context_latents, hidden_states], dim=-1)
+ original_seq_len = x.shape[1]
+
+ pad_length = 0
+ if x.shape[1] % self.patch_size != 0:
+ pad_length = self.patch_size - (x.shape[1] % self.patch_size)
+ x = F.pad(x, (0, 0, 0, pad_length), mode='constant', value=0)
+
+ x = x.transpose(1, 2)
+ x = self.proj_in(x)
+ x = x.transpose(1, 2)
+
+ encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
+
+ seq_len = x.shape[1]
+ cos, sin = self.rotary_emb(x, seq_len=seq_len)
+
+ for layer in self.layers:
+ x = layer(
+ hidden_states=x,
+ temb=timestep_proj,
+ encoder_hidden_states=encoder_hidden_states,
+ position_embeddings=(cos, sin),
+ attention_mask=None,
+ encoder_attention_mask=None
+ )
+
+ shift, scale = (comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
+ x = self.norm_out(x) * (1 + scale) + shift
+
+ x = x.transpose(1, 2)
+ x = self.proj_out(x)
+ x = x.transpose(1, 2)
+
+ x = x[:, :original_seq_len, :]
+ return x
+
+
+class AttentionPooler(nn.Module):
+ def __init__(self, hidden_size, num_layers, head_dim, rms_norm_eps, dtype=None, device=None, operations=None):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
+ self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations)
+ self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))
+
+ self.layers = nn.ModuleList([
+ AceStepEncoderLayer(
+ hidden_size, 16, 8, head_dim, hidden_size * 3, rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(num_layers)
+ ])
+
+ def forward(self, x):
+ B, T, P, D = x.shape
+ x = self.embed_tokens(x)
+ special = self.special_token.expand(B, T, 1, -1)
+ x = torch.cat([special, x], dim=2)
+ x = x.view(B * T, P + 1, D)
+
+ cos, sin = self.rotary_emb(x, seq_len=P + 1)
+ for layer in self.layers:
+ x = layer(x, (cos, sin))
+
+ x = self.norm(x)
+ return x[:, 0, :].view(B, T, D)
+
+
+class FSQ(nn.Module):
+ def __init__(
+ self,
+ levels,
+ dim=None,
+ device=None,
+ dtype=None,
+ operations=None
+ ):
+ super().__init__()
+
+ _levels = torch.tensor(levels, dtype=torch.int32, device=device)
+ self.register_buffer('_levels', _levels, persistent=False)
+
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.int32, device=device), dim=0)
+ self.register_buffer('_basis', _basis, persistent=False)
+
+ self.codebook_dim = len(levels)
+ self.dim = dim if dim is not None else self.codebook_dim
+
+ requires_projection = self.dim != self.codebook_dim
+ if requires_projection:
+ self.project_in = operations.Linear(self.dim, self.codebook_dim, device=device, dtype=dtype)
+ self.project_out = operations.Linear(self.codebook_dim, self.dim, device=device, dtype=dtype)
+ else:
+ self.project_in = nn.Identity()
+ self.project_out = nn.Identity()
+
+ self.codebook_size = self._levels.prod().item()
+
+ indices = torch.arange(self.codebook_size, device=device)
+ implicit_codebook = self._indices_to_codes(indices)
+
+ if dtype is not None:
+ implicit_codebook = implicit_codebook.to(dtype)
+
+ self.register_buffer('implicit_codebook', implicit_codebook, persistent=False)
+
+ def bound(self, z):
+ levels_minus_1 = (self._levels - 1).to(z.dtype)
+ scale = 2. / levels_minus_1
+ bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5
+
+ zhat = bracket.floor()
+ bracket_ste = bracket + (zhat - bracket).detach()
+
+ return scale * bracket_ste - 1.
+
+ def _indices_to_codes(self, indices):
+ indices = indices.unsqueeze(-1)
+ codes_non_centered = (indices // self._basis) % self._levels
+ return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1.
+
+ def codes_to_indices(self, zhat):
+ zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1))
+ return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32)
+
+ def forward(self, z):
+ orig_dtype = z.dtype
+ z = self.project_in(z)
+
+ codes = self.bound(z)
+ indices = self.codes_to_indices(codes)
+
+ out = self.project_out(codes)
+ return out.to(orig_dtype), indices
+
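A worked example of the lattice above, assuming `levels=[3]` and no input/output projection: `bound()` snaps each coordinate onto the grid {-1, 0, 1} via a tanh plus straight-through floor, and `codes_to_indices()` recovers the flat code index through the mixed-radix `_basis`:

```python
import torch

fsq = FSQ(levels=[3])                      # 1-D lattice {-1, 0, 1}
z = torch.tensor([[-5.0], [0.0], [5.0]])
codes, idx = fsq(z)
# codes -> [[-1.], [0.], [1.]]
assert torch.equal(idx, torch.tensor([0, 1, 2], dtype=torch.int32))
```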
+
+class ResidualFSQ(nn.Module):
+ def __init__(
+ self,
+ levels,
+ num_quantizers,
+ dim=None,
+ bound_hard_clamp=True,
+ device=None,
+ dtype=None,
+ operations=None,
+ **kwargs
+ ):
+ super().__init__()
+
+ codebook_dim = len(levels)
+ dim = dim if dim is not None else codebook_dim
+
+ requires_projection = codebook_dim != dim
+ if requires_projection:
+ self.project_in = operations.Linear(dim, codebook_dim, device=device, dtype=dtype)
+ self.project_out = operations.Linear(codebook_dim, dim, device=device, dtype=dtype)
+ else:
+ self.project_in = nn.Identity()
+ self.project_out = nn.Identity()
+
+ self.layers = nn.ModuleList()
+ levels_tensor = torch.tensor(levels, device=device)
+ scales = []
+
+ for ind in range(num_quantizers):
+ scale_val = levels_tensor.float() ** -ind
+ scales.append(scale_val)
+
+ self.layers.append(FSQ(
+ levels=levels,
+ dim=codebook_dim,
+ device=device,
+ dtype=dtype,
+ operations=operations
+ ))
+
+ scales_tensor = torch.stack(scales)
+ if dtype is not None:
+ scales_tensor = scales_tensor.to(dtype)
+ self.register_buffer('scales', scales_tensor, persistent=False)
+
+ if bound_hard_clamp:
+ val = 1 + (1 / (levels_tensor.float() - 1))
+ if dtype is not None:
+ val = val.to(dtype)
+ self.register_buffer('soft_clamp_input_value', val, persistent=False)
+
+ def get_output_from_indices(self, indices, dtype=torch.float32):
+ if indices.dim() == 2:
+ indices = indices.unsqueeze(-1)
+
+ all_codes = []
+ for i, layer in enumerate(self.layers):
+ idx = indices[..., i].long()
+ codes = F.embedding(idx, comfy.model_management.cast_to(layer.implicit_codebook, device=idx.device, dtype=dtype))
+ all_codes.append(codes * comfy.model_management.cast_to(self.scales[i], device=idx.device, dtype=dtype))
+
+ codes_summed = torch.stack(all_codes).sum(dim=0)
+ return self.project_out(codes_summed)
+
+ def forward(self, x):
+ x = self.project_in(x)
+
+ if hasattr(self, 'soft_clamp_input_value'):
+ sc_val = self.soft_clamp_input_value.to(x.dtype)
+ x = (x / sc_val).tanh() * sc_val
+
+ quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
+ residual = x
+ all_indices = []
+
+ for layer, scale in zip(self.layers, self.scales):
+ scale = scale.to(residual.dtype)
+
+ quantized, indices = layer(residual / scale)
+ quantized = quantized * scale
+
+ residual = residual - quantized.detach()
+ quantized_out = quantized_out + quantized
+ all_indices.append(indices)
+
+ quantized_out = self.project_out(quantized_out)
+ all_indices = torch.stack(all_indices, dim=-1)
+
+ return quantized_out, all_indices
+
+
+class AceStepAudioTokenizer(nn.Module):
+ def __init__(
+ self,
+ audio_acoustic_hidden_dim,
+ hidden_size,
+ pool_window_size,
+ fsq_dim,
+ fsq_levels,
+ fsq_input_num_quantizers,
+ num_layers,
+ head_dim,
+ rms_norm_eps,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.audio_acoustic_proj = Linear(audio_acoustic_hidden_dim, hidden_size, dtype=dtype, device=device)
+ self.attention_pooler = AttentionPooler(
+ hidden_size, num_layers, head_dim, rms_norm_eps, dtype=dtype, device=device, operations=operations
+ )
+ self.pool_window_size = pool_window_size
+ self.fsq_dim = fsq_dim
+ self.quantizer = ResidualFSQ(
+ dim=fsq_dim,
+ levels=fsq_levels,
+ num_quantizers=fsq_input_num_quantizers,
+ bound_hard_clamp=True,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.audio_acoustic_proj(hidden_states)
+ hidden_states = self.attention_pooler(hidden_states)
+ quantized, indices = self.quantizer(hidden_states)
+ return quantized, indices
+
+ def tokenize(self, x):
+ B, T, D = x.shape
+ P = self.pool_window_size
+
+ if T % P != 0:
+ pad = P - (T % P)
+ x = F.pad(x, (0, 0, 0, pad))
+ T = x.shape[1]
+
+ T_patch = T // P
+ x = x.view(B, T_patch, P, D)
+
+ quantized, indices = self.forward(x)
+ return quantized, indices
+
+
+class AudioTokenDetokenizer(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ pool_window_size,
+ audio_acoustic_hidden_dim,
+ num_layers,
+ head_dim,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.pool_window_size = pool_window_size
+ self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
+ self.special_tokens = nn.Parameter(torch.empty(1, pool_window_size, hidden_size, dtype=dtype, device=device))
+ self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations)
+ self.layers = nn.ModuleList([
+ AceStepEncoderLayer(
+ hidden_size, 16, 8, head_dim, hidden_size * 3, 1e-6,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(num_layers)
+ ])
+ self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
+ self.proj_out = Linear(hidden_size, audio_acoustic_hidden_dim, dtype=dtype, device=device)
+
+ def forward(self, x):
+ B, T, D = x.shape
+ x = self.embed_tokens(x)
+ x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
+ x = x + comfy.model_management.cast_to(self.special_tokens.expand(B, T, -1, -1), device=x.device, dtype=x.dtype)
+ x = x.view(B * T, self.pool_window_size, D)
+
+ cos, sin = self.rotary_emb(x, seq_len=self.pool_window_size)
+ for layer in self.layers:
+ x = layer(x, (cos, sin))
+
+ x = self.norm(x)
+ x = self.proj_out(x)
+ return x.view(B, T * self.pool_window_size, -1)
+
+
+class AceStepConditionGenerationModel(nn.Module):
+ def __init__(
+ self,
+ in_channels=192,
+ hidden_size=2048,
+ text_hidden_dim=1024,
+ timbre_hidden_dim=64,
+ audio_acoustic_hidden_dim=64,
+ num_dit_layers=24,
+ num_lyric_layers=8,
+ num_timbre_layers=4,
+ num_tokenizer_layers=2,
+ num_heads=16,
+ num_kv_heads=8,
+ head_dim=128,
+ intermediate_size=6144,
+ patch_size=2,
+ pool_window_size=5,
+ rms_norm_eps=1e-06,
+ timestep_mu=-0.4,
+ timestep_sigma=1.0,
+ data_proportion=0.5,
+ sliding_window=128,
+ layer_types=None,
+ fsq_dim=2048,
+ fsq_levels=[8, 8, 8, 5, 5, 5],
+ fsq_input_num_quantizers=1,
+ audio_model=None,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.timestep_mu = timestep_mu
+ self.timestep_sigma = timestep_sigma
+ self.data_proportion = data_proportion
+ self.pool_window_size = pool_window_size
+
+ if layer_types is None:
+ layer_types = []
+ for i in range(num_dit_layers):
+ layer_types.append("sliding_attention" if i % 2 == 0 else "full_attention")
+
+ self.decoder = AceStepDiTModel(
+ in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim,
+ intermediate_size, patch_size, audio_acoustic_hidden_dim,
+ layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.encoder = AceStepConditionEncoder(
+ text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers,
+ num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.tokenizer = AceStepAudioTokenizer(
+ audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.detokenizer = AudioTokenDetokenizer(
+ hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))
+
+ def prepare_condition(
+ self,
+ text_hidden_states, text_attention_mask,
+ lyric_hidden_states, lyric_attention_mask,
+ refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask,
+ src_latents, chunk_masks, is_covers,
+ precomputed_lm_hints_25Hz=None,
+ audio_codes=None
+ ):
+ encoder_hidden, encoder_mask = self.encoder(
+ text_hidden_states, text_attention_mask,
+ lyric_hidden_states, lyric_attention_mask,
+ refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask
+ )
+
+ if precomputed_lm_hints_25Hz is not None:
+ lm_hints = precomputed_lm_hints_25Hz
+ else:
+ if audio_codes is not None:
+ if audio_codes.shape[1] * 5 < src_latents.shape[1]:
+ audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
+ lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
+ else:
+ assert False
+ # TODO ?
+
+ lm_hints = self.detokenizer(lm_hints_5Hz)
+
+ lm_hints = lm_hints[:, :src_latents.shape[1], :]
+ if is_covers is None:
+ src_latents = lm_hints
+ else:
+ src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents)
+
+ context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1)
+
+ return encoder_hidden, encoder_mask, context_latents
+
+ def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs):
+ text_attention_mask = None
+ lyric_attention_mask = None
+ refer_audio_order_mask = None
+ attention_mask = None
+ chunk_masks = None
+ is_covers = None
+ src_latents = None
+ precomputed_lm_hints_25Hz = None
+ lyric_hidden_states = lyric_embed
+ text_hidden_states = context
+ refer_audio_acoustic_hidden_states_packed = refer_audio.movedim(-1, -2)
+
+ x = x.movedim(-1, -2)
+
+ if refer_audio_order_mask is None:
+ refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long)
+
+ if src_latents is None and is_covers is None:
+ src_latents = x
+
+ if chunk_masks is None:
+ chunk_masks = torch.ones_like(x)
+
+ enc_hidden, enc_mask, context_latents = self.prepare_condition(
+ text_hidden_states, text_attention_mask,
+ lyric_hidden_states, lyric_attention_mask,
+ refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask,
+ src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
+ )
+
+ out = self.decoder(hidden_states=x,
+ timestep=timestep,
+ timestep_r=timestep,
+ attention_mask=attention_mask,
+ encoder_hidden_states=enc_hidden,
+ encoder_attention_mask=enc_mask,
+ context_latents=context_latents
+ )
+
+ return out.movedim(-1, -2)
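
The decoder's length handling above deserves a shape walk: `forward` pads the concatenated latents to a multiple of `patch_size`, folds each patch into a single token with the strided `Conv1d`, re-expands with `ConvTranspose1d`, and crops back to the original length. A toy reproduction with hypothetical sizes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

patch, in_ch, hidden, out_ch, L = 2, 8, 16, 4, 7
x = torch.randn(1, L, in_ch)
pad = (patch - L % patch) % patch
x = F.pad(x, (0, 0, 0, pad))                      # (1, 8, 8)
proj_in = nn.Conv1d(in_ch, hidden, patch, stride=patch)
proj_out = nn.ConvTranspose1d(hidden, out_ch, patch, stride=patch)
tokens = proj_in(x.transpose(1, 2))               # (1, 16, 4): one token per patch
y = proj_out(tokens).transpose(1, 2)[:, :L, :]    # cropped back to (1, 7, 4)
```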
diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py
index 51b6d1da8..1f68144e2 100644
--- a/comfy/ldm/hunyuan_video/upsampler.py
+++ b/comfy/ldm/hunyuan_video/upsampler.py
@@ -109,10 +109,10 @@ class HunyuanVideo15SRModel():
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()
- self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+ self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
- return self.model.load_state_dict(sd, strict=True)
+ return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()
diff --git a/comfy/lora.py b/comfy/lora.py
index 7b31d055c..44030bcab 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -332,6 +332,12 @@ def model_lora_keys_unet(model, key_map={}):
key_map["{}".format(key_lora)] = k
key_map["transformer.{}".format(key_lora)] = k
+ if isinstance(model, comfy.model_base.ACEStep15):
+ for k in sdk:
+ if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
+ key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
+ key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
+
return key_map
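
The new branch maps official ACE-Step 1.5 lora keys, which are rooted at `base_model.model.`, onto the wrapped decoder's weights. A worked example of the key surgery (the layer path is illustrative):

```python
k = "diffusion_model.decoder.layers.0.self_attn.q_proj.weight"
key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
assert "base_model.model.{}".format(key_lora) == (
    "base_model.model.layers.0.self_attn.q_proj")
```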
diff --git a/comfy/memory_management.py b/comfy/memory_management.py
new file mode 100644
index 000000000..858bd4cc7
--- /dev/null
+++ b/comfy/memory_management.py
@@ -0,0 +1,81 @@
+import math
+import torch
+from typing import Any, NamedTuple
+
+from comfy.quant_ops import QuantizedTensor
+
+class TensorGeometry(NamedTuple):
+    shape: Any
+ dtype: torch.dtype
+
+ def element_size(self):
+ info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype)
+ return info.bits // 8
+
+ def numel(self):
+ return math.prod(self.shape)
+
+def tensors_to_geometries(tensors, dtype=None):
+ geometries = []
+ for t in tensors:
+ if t is None or isinstance(t, QuantizedTensor):
+ geometries.append(t)
+ continue
+ tdtype = t.dtype
+ if hasattr(t, "_model_dtype"):
+ tdtype = t._model_dtype
+ if dtype is not None:
+ tdtype = dtype
+ geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype))
+ return geometries
+
+def vram_aligned_size(tensor):
+ if isinstance(tensor, list):
+ return sum([vram_aligned_size(t) for t in tensor])
+
+ if isinstance(tensor, QuantizedTensor):
+ inner_tensors, _ = tensor.__tensor_flatten__()
+ return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ])
+
+ if tensor is None:
+ return 0
+
+ size = tensor.numel() * tensor.element_size()
+    alignment_req = 1024
+    return (size + alignment_req - 1) // alignment_req * alignment_req
+
+def interpret_gathered_like(tensors, gathered):
+ offset = 0
+ dest_views = []
+
+ if gathered.dim() != 1 or gathered.element_size() != 1:
+ raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})")
+
+ for tensor in tensors:
+
+ if tensor is None:
+ dest_views.append(None)
+ continue
+
+ if isinstance(tensor, QuantizedTensor):
+ inner_tensors, qt_ctx = tensor.__tensor_flatten__()
+ templates = { attr: getattr(tensor, attr) for attr in inner_tensors }
+ else:
+ templates = { "data": tensor }
+
+ actuals = {}
+ for attr, template in templates.items():
+ size = template.numel() * template.element_size()
+ if offset + size > gathered.numel():
+                raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}.")
+ actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape)
+ offset += vram_aligned_size(template)
+
+ if isinstance(tensor, QuantizedTensor):
+ dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0))
+ else:
+ dest_views.append(actuals["data"])
+
+ return dest_views
+
+aimdo_allocator = None
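
A small usage sketch of the helpers above: `vram_aligned_size` rounds each tensor's byte footprint up to 1 KiB, and `interpret_gathered_like` carves dtype/shape views back out of one flat single-byte buffer at those aligned offsets:

```python
import torch
from comfy.memory_management import vram_aligned_size, interpret_gathered_like

a = torch.randn(3, 4)   # 48 bytes, occupies one 1024-byte slot
b = torch.arange(10)    # 80 bytes, occupies the next slot
buf = torch.empty(vram_aligned_size([a, b]), dtype=torch.uint8)
views = interpret_gathered_like([a, b], buf)
for src, dst in zip((a, b), views):
    dst.copy_(src)
assert torch.equal(views[0], a) and torch.equal(views[1], b)
```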
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 66e52864d..89944548c 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
+import comfy.ldm.ace.ace_step15
import comfy.model_management
import comfy.patcher_extension
@@ -149,6 +150,8 @@ class BaseModel(torch.nn.Module):
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
+ comfy.model_management.archive_model_dtypes(self.diffusion_model)
+
self.adm_channels = unet_config.get("adm_in_channels", None)
if self.adm_channels is None:
self.adm_channels = 0
@@ -299,7 +302,7 @@ class BaseModel(torch.nn.Module):
return out
- def load_model_weights(self, sd, unet_prefix=""):
+ def load_model_weights(self, sd, unet_prefix="", assign=False):
to_load = {}
keys = list(sd.keys())
for k in keys:
@@ -307,7 +310,7 @@ class BaseModel(torch.nn.Module):
to_load[k[len(unet_prefix):]] = sd.pop(k)
to_load = self.model_config.process_unet_state_dict(to_load)
- m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
+ m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign)
if len(m) > 0:
logging.warning("unet missing: {}".format(m))
@@ -322,7 +325,7 @@ class BaseModel(torch.nn.Module):
def process_latent_out(self, latent):
return self.latent_format.process_out(latent)
- def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+ def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
extra_sds = []
if clip_state_dict is not None:
extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
@@ -330,10 +333,7 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
if clip_vision_state_dict is not None:
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
-
- unet_state_dict = self.diffusion_model.state_dict()
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
-
if self.model_type == ModelType.V_PREDICTION:
unet_state_dict["v_pred"] = torch.tensor([])
@@ -776,8 +776,8 @@ class StableAudio1(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
- def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
- sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
+ def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+ sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
for k in d:
s = d[k]
@@ -1541,6 +1541,47 @@ class ACEStep(BaseModel):
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out
+class ACEStep15(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ device = kwargs["device"]
+
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
+        if conditioning_lyrics is not None:
+ out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics)
+
+ refer_audio = kwargs.get("reference_audio_timbre_latents", None)
+ if refer_audio is None or len(refer_audio) == 0:
+ refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
+ 2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
+ -2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
+ 7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
+ 2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
+ -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
+ 5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
+ -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
+ 2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
+ 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
+ -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
+ -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
+ 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
+ else:
+ refer_audio = refer_audio[-1]
+ out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
+
+ audio_codes = kwargs.get("audio_codes", None)
+ if audio_codes is not None:
+ out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
+
+ return out
+
class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 8cea16e50..e8ad725df 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -655,6 +655,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
+ if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys:
+ dit_config = {}
+ dit_config["audio_model"] = "ace1.5"
+ return dit_config
+
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 9d39be7b2..b6291f340 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -19,13 +19,21 @@
import psutil
import logging
from enum import Enum
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
+import threading
import torch
import sys
import platform
import weakref
import gc
import os
+from contextlib import nullcontext
+import comfy.memory_management
+import comfy.utils
+import comfy.quant_ops
+
+import comfy_aimdo.torch
+import comfy_aimdo.model_vbar
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -578,9 +586,15 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS:
+ import comfy.windows
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
+ def get_free_ram():
+ return comfy.windows.get_free_ram()
+else:
+ def get_free_ram():
+ return psutil.virtual_memory().available
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@@ -592,7 +606,7 @@ def extra_reserved_memory():
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
-def free_memory(memory_required, device, keep_loaded=[]):
+def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
cleanup_models_gc()
unloaded_model = []
can_unload = []
@@ -607,15 +621,23 @@ def free_memory(memory_required, device, keep_loaded=[]):
for x in sorted(can_unload):
i = x[-1]
- memory_to_free = None
+ memory_to_free = 1e32
+ ram_to_free = 1e32
if not DISABLE_SMART_MEMORY:
- free_mem = get_free_memory(device)
- if free_mem > memory_required:
- break
- memory_to_free = memory_required - free_mem
- logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
- if current_loaded_models[i].model_unload(memory_to_free):
+ memory_to_free = memory_required - get_free_memory(device)
+ ram_to_free = ram_required - get_free_ram()
+
+ if current_loaded_models[i].model.is_dynamic() and for_dynamic:
+                #don't actually unload dynamic models for the sake of other dynamic models,
+                #since that is handled on demand.
+ memory_required -= current_loaded_models[i].model.loaded_size()
+ memory_to_free = 0
+ if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
+ logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i)
+ if ram_to_free > 0:
+ logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+ current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
@@ -629,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache()
return unloaded_models
-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -650,7 +672,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
models_to_load = []
+    free_for_dynamic = True
for x in models:
+ if not x.is_dynamic():
+ free_for_dynamic = False
loaded_model = LoadedModel(x)
try:
loaded_model_index = current_loaded_models.index(loaded_model)
@@ -676,19 +701,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach()
+
total_memory_required = {}
+ total_ram_required = {}
for loaded_model in models_to_load:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        #x2: one copy so the OS can fit the model in the disk cache while loading, and one for
+        #any pinning we want to do.
+        #FIXME: This should subtract off the current pin consumption of to_load.
+ total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
for device in total_memory_required:
if device != torch.device("cpu"):
- free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
+ free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
for device in total_memory_required:
if device != torch.device("cpu"):
free_mem = get_free_memory(device)
if free_mem < minimum_memory_required:
- models_l = free_memory(minimum_memory_required, device)
+ models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic)
logging.info("{} models unloaded.".format(len(models_l)))
for loaded_model in models_to_load:
@@ -716,6 +747,26 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_loaded_models.insert(0, loaded_model)
return
+def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
+ with torch.inference_mode():
+ load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+ soft_empty_cache()
+
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+    #Deliberately load models outside of the Aimdo mempool so they can be retained across
+    #nodes. Use a dummy thread to do it: PyTorch documents that mempool contexts are
+    #thread-local, so we exploit that to escape the context.
+ if enables_dynamic_vram():
+ t = threading.Thread(
+ target=load_models_gpu_thread,
+ args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
+ )
+ t.start()
+ t.join()
+ else:
+ load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
+ minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
+
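The dummy-thread trick relies on mempool contexts being thread-local. A standalone sketch, assuming CUDA and PyTorch's `torch.cuda.MemPool`/`use_mem_pool` API (recent releases): the helper thread has no active pool context, so its allocation lands in the default caching allocator and outlives the pool:

```python
import threading
import torch

pool = torch.cuda.MemPool()
out = {}

def alloc_outside_pool():
    # no mempool context is active on this thread
    out["t"] = torch.empty(1 << 20, device="cuda")

with torch.cuda.use_mem_pool(pool):
    t = threading.Thread(target=alloc_outside_pool)
    t.start()
    t.join()
```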
def load_model_gpu(model):
return load_models_gpu([model])
@@ -732,6 +783,9 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
do_gc = False
+
+ reset_cast_buffers()
+
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
@@ -749,6 +803,11 @@ def cleanup_models_gc():
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
+def archive_model_dtypes(model):
+ for name, module in model.named_modules():
+ for param_name, param in module.named_parameters(recurse=False):
+ setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
+
def cleanup_models():
to_delete = []
@@ -792,7 +851,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
- if mem_dev > mem_cpu and model_size < mem_dev:
+ if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
return torch_dev
else:
return cpu_dev
@@ -1051,6 +1110,51 @@ def current_stream(device):
return None
stream_counters = {}
+
+STREAM_CAST_BUFFERS = {}
+LARGEST_CASTED_WEIGHT = (None, 0)
+
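+#Per-stream reusable staging buffers for weight casts. Each offload stream keeps one
+#buffer that grows to the largest request seen, so we avoid a fresh VRAM allocation per
+#cast. LARGEST_CASTED_WEIGHT stops a single giant weight from forcing every stream to
+#size its buffer up for it: the second stream returns None and the caller retries.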
+def get_cast_buffer(offload_stream, device, size, ref):
+ global LARGEST_CASTED_WEIGHT
+
+ if offload_stream is not None:
+ wf_context = offload_stream
+ if hasattr(wf_context, "as_context"):
+ wf_context = wf_context.as_context(offload_stream)
+ else:
+ wf_context = nullcontext()
+
+ cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None)
+ if cast_buffer is None or cast_buffer.numel() < size:
+ if ref is LARGEST_CASTED_WEIGHT[0]:
+ #If there is one giant weight we do not want both streams to
+ #allocate a buffer for it. It's up to the caster to get the other
+ #offload stream in this corner case
+ return None
+ if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2):
+ #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now
+ synchronize()
+ del STREAM_CAST_BUFFERS[offload_stream]
+ del cast_buffer
+ #FIXME: This doesn't work in Aimdo because the mempool can't clear its cache
+ soft_empty_cache()
+ with wf_context:
+ cast_buffer = torch.empty((size,), dtype=torch.int8, device=device)
+ STREAM_CAST_BUFFERS[offload_stream] = cast_buffer
+
+ if size > LARGEST_CASTED_WEIGHT[1]:
+ LARGEST_CASTED_WEIGHT = (ref, size)
+
+ return cast_buffer
+
+def reset_cast_buffers():
+ global LARGEST_CASTED_WEIGHT
+ LARGEST_CASTED_WEIGHT = (None, 0)
+ for offload_stream in STREAM_CAST_BUFFERS:
+ if offload_stream is not None:
+ offload_stream.synchronize()
+ STREAM_CAST_BUFFERS.clear()
+ soft_empty_cache()
+
def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0)
if NUM_STREAMS == 0:
@@ -1093,7 +1197,62 @@ def sync_stream(device, stream):
return
current_stream(device).wait_stream(stream)
-def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
+
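+#Copy a list of source tensors into a single flat destination buffer. The sources are laid
+#out back-to-back (None entries keep their slot but are skipped), so e.g. a weight and its
+#bias can ride a single gathered transfer instead of two copies.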
+def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
+ wf_context = nullcontext()
+ if stream is not None:
+ wf_context = stream
+ if hasattr(wf_context, "as_context"):
+ wf_context = wf_context.as_context(stream)
+
+ dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
+ with wf_context:
+ for tensor in tensors:
+ dest_view = dest_views.pop(0)
+ if tensor is None:
+ continue
+ dest_view.copy_(tensor, non_blocking=non_blocking)
+
+
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
+ if hasattr(weight, "_v"):
+ #Unexpected usage patterns: there is no reason these wouldn't work, but they are
+ #untested and no current callers do this.
+ assert r is None
+ assert stream is None
+
+ cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
+
+ if dtype is None:
+ dtype = weight._model_dtype
+
+ r = torch.empty_like(weight, dtype=dtype, device=device)
+
+ signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
+ if signature is not None:
+ raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
+ v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
+ if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+ weight._v_signature = signature
+ #Send it over
+ v_tensor.copy_(weight, non_blocking=non_blocking)
+ #Always take a deep copy even if _v is good, as we have no reasonable point to
+ #unpin a non-comfy weight.
+ r.copy_(v_tensor)
+ comfy_aimdo.model_vbar.vbar_unpin(weight._v)
+ return r
+
+ if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
+ #Offloaded casting could skip this; however, that would make the quantizations
+ #inconsistent between loaded and offloaded weights, so force the double cast that
+ #happens in the regular flow to keep offload deterministic.
+ cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
+ cast_buffer.copy_(weight, non_blocking=non_blocking)
+ weight = cast_buffer
+ r.copy_(weight, non_blocking=non_blocking)
+
+ return r
+
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
@@ -1112,10 +1271,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
- r = torch.empty_like(weight, dtype=dtype, device=device)
+ if r is None:
+ r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:
- r = torch.empty_like(weight, dtype=dtype, device=device)
+ if r is None:
+ r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
@@ -1135,14 +1296,14 @@ if not args.disable_pinned_memory:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
-PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
+PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
def discard_cuda_async_error():
try:
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
- torch.cuda.synchronize()
+ synchronize()
except torch.AcceleratorError:
#Dump it! We already know about it from the synchronous return
pass
@@ -1546,6 +1707,12 @@ def lora_compute_dtype(device):
LORA_COMPUTE_DTYPES[device] = dtype
return dtype
+def synchronize():
+ if is_intel_xpu():
+ torch.xpu.synchronize()
+ elif torch.cuda.is_available():
+ torch.cuda.synchronize()
+
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:
@@ -1557,6 +1724,7 @@ def soft_empty_cache(force=False):
elif is_mlu():
torch.mlu.empty_cache()
elif torch.cuda.is_available():
+ torch.cuda.synchronize()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
@@ -1568,9 +1736,6 @@ def debug_memory_summary():
return torch.cuda.memory.memory_summary()
return ""
-#TODO: might be cleaner to put this somewhere else
-import threading
-
class InterruptProcessingException(Exception):
pass
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index f6b80a40f..d888dbcfb 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -38,19 +38,7 @@ from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
-
-def string_to_seed(data):
- crc = 0xFFFFFFFF
- for byte in data:
- if isinstance(byte, str):
- byte = ord(byte)
- crc ^= byte
- for _ in range(8):
- if crc & 1:
- crc = (crc >> 1) ^ 0xEDB88320
- else:
- crc >>= 1
- return crc ^ 0xFFFFFFFF
+import comfy_aimdo.model_vbar
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@@ -123,6 +111,10 @@ def move_weight_functions(m, device):
memory += f.move_to(device=device)
return memory
+def string_to_seed(data):
+ logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils")
+ return comfy.utils.string_to_seed(data)
+
class LowVramPatch:
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
@@ -169,6 +161,11 @@ def get_key_weight(model, key):
return weight, set_func, convert_func
+def key_param_name_to_key(key, param):
+ if len(key) == 0:
+ return param
+ return "{}.{}".format(key, param)
+
class AutoPatcherEjector:
def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False):
self.model = model
@@ -212,6 +209,27 @@ class MemoryCounter:
def decrement(self, used: int):
self.value -= used
+CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0)
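+#A sentinel "device" reported by LazyCastingParam below, presumably so serialization code
+#always sees a device mismatch and calls .to(), which LazyCastingParam intercepts.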
+
+class LazyCastingParam(torch.nn.Parameter):
+ def __new__(cls, model, key, tensor):
+ return super().__new__(cls, tensor)
+
+ def __init__(self, model, key, tensor):
+ self.model = model
+ self.key = key
+
+ @property
+ def device(self):
+ return CustomTorchDevice
+
+ #safetensors will .to() us to the CPU, which we catch here to cast on demand. The
+ #returned tensor is then just a short-lived object inside safetensors' big
+ #serialization loop over all weights, garbage collected per weight.
+ def to(self, *args, **kwargs):
+ return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu")
+
+
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
@@ -269,6 +287,9 @@ class ModelPatcher:
if not hasattr(self.model, 'model_offload_buffer_memory'):
self.model.model_offload_buffer_memory = 0
+ def is_dynamic(self):
+ return False
+
def model_size(self):
if self.size > 0:
return self.size
@@ -284,6 +305,9 @@ class ModelPatcher:
def lowvram_patch_counter(self):
return self.model.lowvram_patch_counter
+ def get_free_memory(self, device):
+ return comfy.model_management.get_free_memory(device)
+
def clone(self):
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {}
@@ -611,14 +635,14 @@ class ModelPatcher:
sd.pop(k)
return sd
- def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
- if key not in self.patches:
- return
-
+ def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False):
weight, set_func, convert_func = get_key_weight(self.model, key)
+ if key not in self.patches:
+ return weight
+
inplace_update = self.weight_inplace_update or inplace_update
- if key not in self.backup:
+ if key not in self.backup and not return_weight:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
@@ -631,13 +655,15 @@ class ModelPatcher:
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
if set_func is None:
- out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
- if inplace_update:
+ out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
+ if return_weight:
+ return out_weight
+ elif inplace_update:
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
comfy.utils.set_attr_param(self.model, key, out_weight)
else:
- set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
+ return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight)
def pin_weight_to_device(self, key):
weight, set_func, convert_func = get_key_weight(self.model, key)
@@ -654,7 +680,7 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
- def _load_list(self):
+ def _load_list(self, prio_comfy_cast_weights=False):
loading = []
for n, m in self.model.named_modules():
params = []
@@ -681,7 +707,8 @@ class ModelPatcher:
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
- loading.append((module_offload_mem, module_mem, n, m, params))
+ prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
+ loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -773,7 +800,7 @@ class ModelPatcher:
continue
for param in params:
- key = "{}.{}".format(n, param)
+ key = key_param_name_to_key(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
if comfy.model_management.is_device_cuda(device_to):
@@ -789,7 +816,7 @@ class ModelPatcher:
n = x[1]
params = x[3]
for param in params:
- self.pin_weight_to_device("{}.{}".format(n, param))
+ self.pin_weight_to_device(key_param_name_to_key(n, param))
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
if lowvram_counter > 0:
@@ -895,7 +922,7 @@ class ModelPatcher:
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
move_weight = True
for param in params:
- key = "{}.{}".format(n, param)
+ key = key_param_name_to_key(n, param)
bk = self.backup.get(key, None)
if bk is not None:
if not lowvram_possible:
@@ -946,7 +973,7 @@ class ModelPatcher:
logging.debug("freed {}".format(n))
for param in params:
- self.pin_weight_to_device("{}.{}".format(n, param))
+ self.pin_weight_to_device(key_param_name_to_key(n, param))
self.model.model_lowvram = True
@@ -984,6 +1011,9 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used
+ def partially_unload_ram(self, ram_to_unload):
+ pass
+
def detach(self, unpatch_all=True):
self.eject_model()
self.model_patches_to(self.offload_device)
@@ -1317,10 +1347,10 @@ class ModelPatcher:
key, original_weights=original_weights)
del original_weights[key]
if set_func is None:
- out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key))
+ out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
comfy.utils.copy_to_param(self.model, key, out_weight)
else:
- set_func(out_weight, inplace_update=True, seed=string_to_seed(key))
+ set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key))
if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed:
# TODO: disable caching if not enough system RAM to do so
target_device = self.offload_device
@@ -1355,7 +1385,249 @@ class ModelPatcher:
self.unpatch_hooks()
self.clear_cached_hook_weights()
+ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
+ unet_state_dict = self.model.diffusion_model.state_dict()
+ for k, v in unet_state_dict.items():
+ op_keys = k.rsplit('.', 1)
+ if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
+ continue
+ try:
+ op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
+ except:
+ continue
+ if not op or not hasattr(op, "comfy_cast_weights") or \
+ (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True):
+ continue
+ key = "diffusion_model." + k
+ unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key))
+ return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
+
def __del__(self):
self.unpin_all_weights()
self.detach(unpatch_all=False)
+class ModelPatcherDynamic(ModelPatcher):
+
+ def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False):
+ if load_device is not None and comfy.model_management.is_device_cpu(load_device):
+ #reroute to default MP for CPUs
+ return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update)
+ return super().__new__(cls)
+
+ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
+ super().__init__(model, load_device, offload_device, size, weight_inplace_update)
+ #This is now far more dynamic, and we don't support sharing the same base model
+ #between dynamic and non-dynamic patchers.
+ if hasattr(self.model, "model_loaded_weight_memory"):
+ del self.model.model_loaded_weight_memory
+ if not hasattr(self.model, "dynamic_vbars"):
+ self.model.dynamic_vbars = {}
+ assert load_device is not None
+
+ def is_dynamic(self):
+ return True
+
+ def _vbar_get(self, create=False):
+ if self.load_device == torch.device("cpu"):
+ return None
+ vbar = self.model.dynamic_vbars.get(self.load_device, None)
+ if create and vbar is None:
+ # x10: we don't know what model-defined type casts will land in the vbar, but
+ # virtual address space is nearly free. This covers someone casting an entire
+ # model from FP4 to FP32 with room to spare.
+ vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index)
+ self.model.dynamic_vbars[self.load_device] = vbar
+ return vbar
+
+ def loaded_size(self):
+ vbar = self._vbar_get()
+ if vbar is None:
+ return 0
+ return vbar.loaded_size()
+
+ def get_free_memory(self, device):
+ #NOTE: at high condition / batch counts the estimate should have already vacated
+ #all non-dynamic models, so this is safe even if it's not 100% true that all of
+ #this would be available for inference use.
+ return comfy.model_management.get_total_memory(device) - self.model_size()
+
+ #Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
+
+ def pin_weight_to_device(self, key):
+ raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading")
+
+ def unpin_weight(self, key):
+ raise RuntimeError("unpin_weight invalid for dymamic weight loading")
+
+ def unpin_all_weights(self):
+ self.partially_unload_ram(1e32)
+
+ def memory_required(self, input_shape):
+ #Pad this significantly. We are trying to get away from precise estimates. This
+ #estimate is only used when a ModelPatcherDynamic runs after a ModelPatcher; if
+ #everything uses ModelPatcherDynamic it is ignored and it's all done dynamically.
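+ #e.g. a 4GB base estimate becomes 4 * 1.3 + 1GiB, roughly 6.2GB of padded headroom.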
+ return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
+
+
+ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
+
+ #Force patching doesn't make sense in dynamic loading, as you don't know what does
+ #and doesn't need to be forced at this stage. The only thing you could do would be
+ #to patch it all on the CPU, which consumes huge amounts of RAM.
+ assert not force_patch_weights
+
+ #Full load doesn't make sense either, as we don't actually have any loader
+ #capability here and now.
+ assert not full_load
+
+ assert device_to == self.load_device
+
+ num_patches = 0
+ allocated_size = 0
+
+ with self.use_ejected():
+ self.unpatch_hooks()
+
+ vbar = self._vbar_get(create=True)
+ if vbar is not None:
+ vbar.prioritize()
+
+ #We have way more tools for acceleration on comfy weight offloading, so always
+ #prioritize the non-comfy weights (note the order reverse).
+ loading = self._load_list(prio_comfy_cast_weights=True)
+ loading.sort(reverse=True)
+
+ for x in loading:
+ _, _, _, n, m, params = x
+
+ def set_dirty(item, dirty):
+ if dirty or not hasattr(item, "_v_signature"):
+ item._v_signature = None
+
+ def setup_param(self, m, n, param_key):
+ nonlocal num_patches
+ key = key_param_name_to_key(n, param_key)
+
+ weight_function = []
+
+ weight, _, _ = get_key_weight(self.model, key)
+ if weight is None:
+ return 0
+ if key in self.patches:
+ setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
+ num_patches += 1
+ else:
+ setattr(m, param_key + "_lowvram_function", None)
+
+ if key in self.weight_wrapper_patches:
+ weight_function.extend(self.weight_wrapper_patches[key])
+ setattr(m, param_key + "_function", weight_function)
+ geometry = weight
+ if not isinstance(weight, QuantizedTensor):
+ model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
+ weight._model_dtype = model_dtype
+ geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
+ return comfy.memory_management.vram_aligned_size(geometry)
+
+ if hasattr(m, "comfy_cast_weights"):
+ m.comfy_cast_weights = True
+ m.pin_failed = False
+ m.seed_key = n
+ set_dirty(m, dirty)
+
+ v_weight_size = 0
+ v_weight_size += setup_param(self, m, n, "weight")
+ v_weight_size += setup_param(self, m, n, "bias")
+
+ if vbar is not None and not hasattr(m, "_v"):
+ m._v = vbar.alloc(v_weight_size)
+ allocated_size += v_weight_size
+
+ else:
+ for param in params:
+ key = key_param_name_to_key(n, param)
+ weight, _, _ = get_key_weight(self.model, key)
+ weight.seed_key = key
+ set_dirty(weight, dirty)
+ geometry = weight
+ model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
+ geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
+ weight_size = geometry.numel() * geometry.element_size()
+ if vbar is not None and not hasattr(weight, "_v"):
+ weight._v = vbar.alloc(weight_size)
+ weight._model_dtype = model_dtype
+ allocated_size += weight_size
+
+ logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
+
+ self.model.device = device_to
+ self.model.current_weight_patches_uuid = self.patches_uuid
+
+ for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
+ #These are all super dangerous. Who knows what the custom nodes actually do here...
+ callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
+
+ self.apply_hooks(self.forced_hooks, force_apply=True)
+
+ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
+ assert not force_patch_weights #See above
+ assert self.load_device != torch.device("cpu")
+
+ vbar = self._vbar_get()
+ return 0 if vbar is None else vbar.free_memory(memory_to_free)
+
+ def partially_unload_ram(self, ram_to_unload):
+ loading = self._load_list(prio_comfy_cast_weights=True)
+ for x in loading:
+ _, _, _, _, m, _ = x
+ ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
+ if ram_to_unload <= 0:
+ return
+
+ def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
+ #This isn't used by the core at all and can only be used to load a model outside
+ #the control of proper model_management. If you are a custom node author reading
+ #this, the correct pattern is to call load_models_gpu() to get a proper
+ #managed load of your model.
+ assert not load_weights
+ return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights)
+
+ def unpatch_model(self, device_to=None, unpatch_weights=True):
+ super().unpatch_model(device_to=None, unpatch_weights=False)
+
+ if unpatch_weights:
+ self.partially_unload_ram(1e32)
+ self.partially_unload(None, 1e32)
+
+ def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
+ assert not force_patch_weights #See above
+ with self.use_ejected(skip_and_inject_on_exit_only=True):
+ dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid)
+
+ self.unpatch_model(self.offload_device, unpatch_weights=False)
+ self.patch_model(load_weights=False)
+
+ try:
+ self.load(device_to, dirty=dirty)
+ except Exception as e:
+ self.detach()
+ raise e
+ #ModelPatcher::partially_load returns the amount that got loaded, but nothing in
+ #core uses it and we have no such figure in the Dynamic world. Hand the custom
+ #node devs a None rather than a 0 that would mislead any logic they might have.
+ return None
+
+ def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
+ assert False #Should be unreachable - we don't ever cache in the new implementation
+
+ def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
+ if key not in combined_patches:
+ return
+
+ raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup")
+
+ def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
+ pass
+
+CoreModelPatcher = ModelPatcher
diff --git a/comfy/ops.py b/comfy/ops.py
index e406ba7ed..53c5e4dc3 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -19,10 +19,16 @@
import torch
import logging
import comfy.model_management
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import comfy.float
import comfy.rmsnorm
import json
+import comfy.memory_management
+import comfy.pinned_memory
+import comfy.utils
+
+import comfy_aimdo.model_vbar
+import comfy_aimdo.torch
def run_every_op():
if torch.compiler.is_compiling():
@@ -72,7 +78,115 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
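+ #The comfy_aimdo protocol, as read from its call sites here: vbar_fault() pins the
+ #module's reserved VRAM slot and returns a residency signature (None if no slot is
+ #available). Comparing that signature with the one cached on the module tells us
+ #whether our bytes survived from a previous cast; only on a mismatch do we re-upload
+ #from RAM, optionally staging through a pinned buffer. uncast_bias_weight() later
+ #calls vbar_unpin() to release the slot.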
+ offload_stream = None
+ xfer_dest = None
+ cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
+
+ signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
+ if signature is not None:
+ xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
+ resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
+
+ if not resident:
+ cast_dest = None
+
+ xfer_source = [ s.weight, s.bias ]
+
+ pin = comfy.pinned_memory.get_pin(s)
+ if pin is not None:
+ xfer_source = [ pin ]
+
+ for data, geometry in zip([ s.weight, s.bias ], cast_geometry):
+ if data is None:
+ continue
+ if data.dtype != geometry.dtype:
+ cast_dest = xfer_dest
+ if cast_dest is None:
+ cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
+ xfer_dest = None
+ break
+
+ dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
+ offload_stream = comfy.model_management.get_offload_stream(device)
+ if xfer_dest is None and offload_stream is not None:
+ xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
+ if xfer_dest is None:
+ offload_stream = comfy.model_management.get_offload_stream(device)
+ xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
+ if xfer_dest is None:
+ xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
+ offload_stream = None
+
+ if signature is None and pin is None:
+ comfy.pinned_memory.pin_memory(s)
+ pin = comfy.pinned_memory.get_pin(s)
+ else:
+ pin = None
+
+ if pin is not None:
+ comfy.model_management.cast_to_gathered(xfer_source, pin)
+ xfer_source = [ pin ]
+ #send it over
+ comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+ comfy.model_management.sync_stream(device, offload_stream)
+
+ if cast_dest is not None:
+ for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
+ comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
+ if post_cast is not None:
+ post_cast.copy_(pre_cast)
+ xfer_dest = cast_dest
+
+ params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
+ weight = params[0]
+ bias = params[1]
+
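+ #post_cast runs the per-weight patch pipeline once the raw bytes are on device:
+ #dequantize if the dtypes demand it, apply the attached lowvram (lora) function in
+ #the compute dtype when the bytes were freshly uploaded, optionally re-quantize with
+ #stochastic rounding and write the result back into the vbar copy, then apply any
+ #wrapper weight functions.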
+ def post_cast(s, param_key, x, dtype, resident, update_weight):
+ lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+ fns = getattr(s, param_key + "_function", [])
+
+ orig = x
+
+ def to_dequant(tensor, dtype):
+ tensor = tensor.to(dtype=dtype)
+ if isinstance(tensor, QuantizedTensor):
+ tensor = tensor.dequantize()
+ return tensor
+
+ if orig.dtype != dtype or len(fns) > 0:
+ x = to_dequant(x, dtype)
+ if not resident and lowvram_fn is not None:
+ x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
+ #FIXME: this is not accurate, we need to be sensitive to the compute dtype
+ x = lowvram_fn(x)
+ if (isinstance(orig, QuantizedTensor) and
+ (orig.dtype == dtype and len(fns) == 0 or update_weight)):
+ seed = comfy.utils.string_to_seed(s.seed_key)
+ y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
+ if orig.dtype == dtype and len(fns) == 0:
+ #The layer actually wants our freshly saved QT
+ x = y
+ else:
+ y = x
+ if update_weight:
+ orig.copy_(y)
+ for f in fns:
+ x = f(x)
+ return x
+
+ update_weight = signature is not None
+
+ weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
+ if s.bias is not None:
+ bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
+ s._v_signature = signature
+
+ #FIXME: weird offload return protocol
+ return weight, bias, (offload_stream, device if signature is not None else None, None)
+
+
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
# NOTE: offloadable=False is a legacy behavior and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
@@ -87,22 +201,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
+ non_blocking = comfy.model_management.device_supports_non_blocking(device)
+
+ if hasattr(s, "_v"):
+ return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
+
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device)
else:
offload_stream = None
- non_blocking = comfy.model_management.device_supports_non_blocking(device)
+ bias = None
+ weight = None
+
+ if offload_stream is not None and not args.cuda_malloc:
+ cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
+ cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
+ #The streams can be uneven in buffer capability and reject us. Retry to get the other stream
+ if cast_buffer is None:
+ offload_stream = comfy.model_management.get_offload_stream(device)
+ cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
+ params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
+ weight = params[0]
+ bias = params[1]
weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
- weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+ weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight)
- bias = None
if s.bias is not None:
- bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+ bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias)
comfy.model_management.sync_stream(device, offload_stream)
@@ -110,6 +240,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
weight_a = weight
if s.bias is not None:
+ bias = bias.to(dtype=bias_dtype)
for f in s.bias_function:
bias = f(bias)
@@ -131,14 +262,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
os, weight_a, bias_a = offload_stream
+ device = None
+ #FIXME: This is not good RTTI, so also check for _v explicitly
+ if hasattr(s, "_v") and not isinstance(weight_a, torch.Tensor):
+ comfy_aimdo.model_vbar.vbar_unpin(s._v)
+ device = weight_a
if os is None:
return
- if weight_a is not None:
- device = weight_a.device
- else:
- if bias_a is None:
- return
- device = bias_a.device
+ if device is None:
+ if weight_a is not None:
+ device = weight_a.device
+ else:
+ if bias_a is None:
+ return
+ device = bias_a.device
os.wait_stream(comfy.model_management.current_stream(device))
@@ -149,6 +286,57 @@ class CastWeightBiasOp:
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
+
+ def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
+ if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
+ super().__init__(in_features, out_features, bias, device, dtype)
+ return
+
+ # The issue is that `torch.empty` still reserves the full memory for the layer.
+ # Windows doesn't over-commit memory, so without this we are momentarily
+ # commit-charged for the weight even though we might zero-copy it when we load the
+ # state dict. If the commit charge exceeds the ceiling we can destabilize the
+ # system.
+ torch.nn.Module.__init__(self)
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = None
+ self.bias = None
+ self.comfy_need_lazy_init_bias = bias
+ self.weight_comfy_model_dtype = dtype
+ self.bias_comfy_model_dtype = dtype
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys, error_msgs):
+
+ if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
+ return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs)
+ assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
+ prefix_len = len(prefix)
+ for k,v in state_dict.items():
+ if k[prefix_len:] == "weight":
+ if not assign_to_params_buffers:
+ v = v.clone()
+ self.weight = torch.nn.Parameter(v, requires_grad=False)
+ elif k[prefix_len:] == "bias" and v is not None:
+ if not assign_to_params_buffers:
+ v = v.clone()
+ self.bias = torch.nn.Parameter(v, requires_grad=False)
+ else:
+ unexpected_keys.append(k)
+
+ #Reconcile default construction of the weight if it's missing. nn.Linear stores
+ #its weight with shape (out_features, in_features).
+ if self.weight is None:
+ v = torch.zeros(self.out_features, self.in_features)
+ self.weight = torch.nn.Parameter(v, requires_grad=False)
+ missing_keys.append(prefix+"weight")
+ if self.bias is None and self.comfy_need_lazy_init_bias:
+ v = torch.zeros(self.out_features,)
+ self.bias = torch.nn.Parameter(v, requires_grad=False)
+ missing_keys.append(prefix+"bias")
+
+
def reset_parameters(self):
return None
@@ -655,8 +843,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
- def forward_comfy_cast_weights(self, input):
- weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ def forward_comfy_cast_weights(self, input, compute_dtype=None):
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@@ -666,6 +854,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
input_shape = input.shape
reshaped_3d = False
+ #If cast needs to apply lora, it should be done in the compute dtype
+ compute_dtype = input.dtype
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
@@ -684,7 +874,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
- output = self.forward_comfy_cast_weights(input)
+
+ output = self.forward_comfy_cast_weights(input, compute_dtype)
# Reshape output back to 3D if input was 3D
if reshaped_3d:
diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py
new file mode 100644
index 000000000..8acc327a7
--- /dev/null
+++ b/comfy/pinned_memory.py
@@ -0,0 +1,29 @@
+import torch
+import comfy.model_management
+import comfy.memory_management
+
+from comfy.cli_args import args
+
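+#Per-module pinned staging: pin_memory() allocates one page-locked CPU buffer sized to hold
+#the module's weight and bias gathered back-to-back and caches it on the module as "_pin".
+#Pinned sources let the H2D copies on the offload stream run truly asynchronously.
+#unpin_memory() releases the buffer and reports the bytes freed so callers can budget RAM.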
+def get_pin(module):
+ return getattr(module, "_pin", None)
+
+def pin_memory(module):
+ if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
+ return
+ #FIXME: This is a RAM cache trigger event
+ size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
+ pin = torch.empty((size,), dtype=torch.uint8)
+ if comfy.model_management.pin_memory(pin):
+ module._pin = pin
+ else:
+ module.pin_failed = True
+ return False
+ return True
+
+def unpin_memory(module):
+ if get_pin(module) is None:
+ return 0
+ size = module._pin.numel() * module._pin.element_size()
+ comfy.model_management.unpin_memory(module._pin)
+ del module._pin
+ return size
diff --git a/comfy/samplers.py b/comfy/samplers.py
index 1989ef107..8b9782956 100755
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -9,7 +9,6 @@ if TYPE_CHECKING:
import torch
from functools import partial
import collections
-from comfy import model_management
import math
import logging
import comfy.sampler_helpers
@@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
- free_memory = model_management.get_free_memory(x_in.device)
+ free_memory = model.current_patcher.get_free_memory(x_in.device)
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
diff --git a/comfy/sd.py b/comfy/sd.py
index f627f7d55..bc63d6ced 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -59,6 +59,7 @@ import comfy.text_encoders.kandinsky5
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.newbie
import comfy.text_encoders.anima
+import comfy.text_encoders.ace15
import comfy.model_patcher
import comfy.lora
@@ -228,8 +229,10 @@ class CLIP:
self.cond_stage_model.to(offload_device)
logging.warning("Had to shift TE back.")
+ model_management.archive_model_dtypes(self.cond_stage_model)
+
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
- self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+ self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
@@ -389,8 +392,18 @@ class CLIP:
def load_sd(self, sd, full_model=False):
if full_model:
- return self.cond_stage_model.load_state_dict(sd, strict=False)
+ return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
else:
+ can_assign = self.patcher.is_dynamic()
+ self.cond_stage_model.can_assign_sd = can_assign
+
+ # The CLIP models are a pretty complex web of wrappers and it's
+ # a bit of an API change to plumb this all the way through.
+ # So spray-paint the model with this flag, which the loading
+ # nn.Module can then inspect for itself.
+ for m in self.cond_stage_model.modules():
+ m.can_assign_sd = can_assign
+
return self.cond_stage_model.load_sd(sd)
def get_sd(self):
@@ -440,6 +453,8 @@ class VAE:
self.extra_1d_channel = None
self.crop_input = True
+ self.audio_sample_rate = 44100
+
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -537,14 +552,27 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
- self.first_stage_model = AudioOobleckVAE()
+ config = {}
+ param_key = None
+ self.upscale_ratio = 2048
+ self.downscale_ratio = 2048
+ if "decoder.layers.2.layers.1.weight_v" in sd:
+ param_key = "decoder.layers.2.layers.1.weight_v"
+ if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd:
+ param_key = "decoder.layers.2.layers.1.parametrizations.weight.original1"
+ if param_key is not None:
+ if sd[param_key].shape[-1] == 12:
+ config["strides"] = [2, 4, 4, 6, 10]
+ self.audio_sample_rate = 48000
+ self.upscale_ratio = 1920
+ self.downscale_ratio = 1920
+
+ self.first_stage_model = AudioOobleckVAE(**config)
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
self.latent_channels = 64
self.output_channels = 2
self.pad_channel_value = "replicate"
- self.upscale_ratio = 2048
- self.downscale_ratio = 2048
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
@@ -765,12 +793,7 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
- m, u = self.first_stage_model.load_state_dict(sd, strict=False)
- if len(m) > 0:
- logging.warning("Missing VAE keys {}".format(m))
-
- if len(u) > 0:
- logging.debug("Leftover VAE keys {}".format(u))
+ model_management.archive_model_dtypes(self.first_stage_model)
if device is None:
device = model_management.vae_device()
@@ -782,7 +805,18 @@ class VAE:
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
- self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
+ mp = comfy.model_patcher.CoreModelPatcher
+ if self.disable_offload:
+ mp = comfy.model_patcher.ModelPatcher
+ self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device)
+
+ m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
+ if len(m) > 0:
+ logging.warning("Missing VAE keys {}".format(m))
+
+ if len(u) > 0:
+ logging.debug("Leftover VAE keys {}".format(u))
+
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
@@ -838,7 +872,7 @@ class VAE:
/ 3.0)
return output
- def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
+ def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
@@ -897,7 +931,7 @@ class VAE:
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
- free_memory = model_management.get_free_memory(self.device)
+ free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
@@ -971,7 +1005,7 @@ class VAE:
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
- free_memory = model_management.get_free_memory(self.device)
+ free_memory = self.patcher.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
samples = None
@@ -1409,6 +1443,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_data_jina = clip_data[0]
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
+ elif clip_type == CLIPType.ACE:
+ te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
+ if TEModel.QWEN3_4B in te_models:
+ model_type = "qwen3_4b"
+ else:
+ model_type = "qwen3_2b"
+ clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -1432,7 +1474,7 @@ def load_gligen(ckpt_path):
model = gligen.load_gligen(data)
if model_management.should_use_fp16():
model = model.half()
- return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
+ return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device())
def model_detection_error_hint(path, state_dict):
filename = os.path.basename(path)
@@ -1520,7 +1562,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
- model.load_model_weights(sd, diffusion_model_prefix)
+ model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
+ model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
@@ -1563,7 +1606,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
logging.debug("left over keys: {}".format(left_over))
if output_model:
- model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
if inital_load_device != torch.device("cpu"):
logging.info("loaded diffusion model directly to GPU")
model_management.load_models_gpu([model_patcher], force_full_load=True)
@@ -1655,13 +1697,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "")
- model = model.to(offload_device)
- model.load_model_weights(new_sd, "")
+ model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
+ if not model_management.is_device_cpu(offload_device):
+ model.to(offload_device)
+ model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in diffusion model: {}".format(left_over))
- return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
-
+ return model_patcher
def load_diffusion_model(unet_path, model_options={}):
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
@@ -1692,9 +1735,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if metadata is None:
metadata = {}
- model_management.load_models_gpu(load_models, force_patch_weights=True)
+ model_management.load_models_gpu(load_models)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
- sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
+ sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index d4f22120b..4c817d468 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -155,6 +155,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all":
pass
+ elif isinstance(layer_idx, list):
+ self.layer = layer_idx
elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
@@ -297,7 +299,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return self(tokens)
def load_sd(self, sd):
- return self.transformer.load_state_dict(sd, strict=False)
+ return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
def parse_parentheses(string):
result = []
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index d25271d6e..77264ed28 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -24,6 +24,7 @@ import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.anima
+import comfy.text_encoders.ace15
from . import supported_models_base
from . import latent_formats
@@ -1596,6 +1597,46 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
+class ACEStep15(supported_models_base.BASE):
+ unet_config = {
+ "audio_model": "ace1.5",
+ }
+
+ unet_extra_config = {
+ }
+
+ sampling_settings = {
+ "multiplier": 1.0,
+ "shift": 3.0,
+ }
+
+ latent_format = comfy.latent_formats.ACEAudio15
+
+ memory_usage_factor = 4.7
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.ACEStep15(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
+ detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
+ if "dtype_llama" in detect_2b:
+ detect = detect_2b
+ detect["lm_model"] = "qwen3_2b"
+ elif "dtype_llama" in detect_4b:
+ detect = detect_4b
+ detect["lm_model"] = "qwen3_4b"
+
+ return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
+
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
new file mode 100644
index 000000000..fce2b67ce
--- /dev/null
+++ b/comfy/text_encoders/ace15.py
@@ -0,0 +1,249 @@
+from .anima import Qwen3Tokenizer
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import torch
+import math
+import comfy.utils
+import comfy.model_management
+
+
+def sample_manual_loop_no_classes(
+ model,
+ ids=None,
+ paddings=[],
+ execution_dtype=None,
+ cfg_scale: float = 2.0,
+ temperature: float = 0.85,
+ top_p: float = 0.9,
+ top_k: int = None,
+ seed: int = 1,
+ min_tokens: int = 1,
+ max_new_tokens: int = 2048,
+ audio_start_id: int = 151669, # The cutoff ID for audio codes
+ audio_end_id: int = 215669,
+ eos_token_id: int = 151645,
+):
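+ #Autoregressive sampling with classifier-free guidance: row 0 of the batch holds the
+ #conditional prompt, row 1 the unconditional one, and the logits are recombined as
+ #uncond + cfg_scale * (cond - uncond). Everything outside the audio code range
+ #[audio_start_id, audio_end_id) is masked off (EOS allowed once min_tokens is
+ #reached), then the next token is drawn with top-k / top-p / temperature sampling
+ #from a seeded generator.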
+ device = model.execution_device
+
+ if execution_dtype is None:
+ if comfy.model_management.should_use_bf16(device):
+ execution_dtype = torch.bfloat16
+ else:
+ execution_dtype = torch.float32
+
+ embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
+ for i, t in enumerate(paddings):
+ attention_mask[i, :t] = 0
+ attention_mask[i, t:] = 1
+
+ output_audio_codes = []
+ past_key_values = []
+ generator = torch.Generator(device=device)
+ generator.manual_seed(seed)
+ model_config = model.transformer.model.config
+
+ for x in range(model_config.num_hidden_layers):
+ past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+
+ progress_bar = comfy.utils.ProgressBar(max_new_tokens)
+
+ for step in range(max_new_tokens):
+ outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
+ next_token_logits = model.transformer.logits(outputs[0])[:, -1]
+ past_key_values = outputs[2]
+
+ cond_logits = next_token_logits[0:1]
+ uncond_logits = next_token_logits[1:2]
+ cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+
+ if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ eos_score = cfg_logits[:, eos_token_id].clone()
+
+ remove_logit_value = torch.finfo(cfg_logits.dtype).min
+ # Only generate audio tokens
+ cfg_logits[:, :audio_start_id] = remove_logit_value
+ cfg_logits[:, audio_end_id:] = remove_logit_value
+
+ if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ cfg_logits[:, eos_token_id] = eos_score
+
+ if top_k is not None and top_k > 0:
+ top_k_vals, _ = torch.topk(cfg_logits, top_k)
+ min_val = top_k_vals[..., -1, None]
+ cfg_logits[cfg_logits < min_val] = remove_logit_value
+
+ if top_p is not None and top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+ cfg_logits[indices_to_remove] = remove_logit_value
+
+ if temperature > 0:
+ cfg_logits = cfg_logits / temperature
+ next_token = torch.multinomial(torch.softmax(cfg_logits, dim=-1), num_samples=1, generator=generator).squeeze(1)
+ else:
+ next_token = torch.argmax(cfg_logits, dim=-1)
+
+ token = next_token.item()
+
+ if token == eos_token_id:
+ break
+
+ embed, _, _, _ = model.process_tokens([[token]], device)
+ embeds = embed.repeat(2, 1, 1)
+ attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
+
+ output_audio_codes.append(token - audio_start_id)
+ progress_bar.update_absolute(step)
+
+ return output_audio_codes
+
+
+def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
+ cfg_scale = 2.0
+
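+ #The cond and uncond token lists must batch together, so the shorter one is
+ #left-padded with the pad token up to the longer length. The pad counts travel as
+ #"paddings" so the sampler can zero those positions in the attention mask.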
+ positive = [[token for token, _ in inner_list] for inner_list in positive]
+ negative = [[token for token, _ in inner_list] for inner_list in negative]
+ positive = positive[0]
+ negative = negative[0]
+
+ neg_pad = 0
+ if len(negative) < len(positive):
+ neg_pad = (len(positive) - len(negative))
+ negative = [model.special_tokens["pad"]] * neg_pad + negative
+
+ pos_pad = 0
+ if len(negative) > len(positive):
+ pos_pad = (len(negative) - len(positive))
+ positive = [model.special_tokens["pad"]] * pos_pad + positive
+
+ paddings = [pos_pad, neg_pad]
+ return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+
+
+class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)
+
+ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
+ out = {}
+ lyrics = kwargs.get("lyrics", "")
+ bpm = kwargs.get("bpm", 120)
+ duration = kwargs.get("duration", 120)
+ keyscale = kwargs.get("keyscale", "C major")
+ timesignature = kwargs.get("timesignature", 2)
+ language = kwargs.get("language", "en")
+ seed = kwargs.get("seed", 0)
+
+ duration = math.ceil(duration)
+ meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
+ lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n\n{}\n\n\n<|im_end|>\n"
+
+ meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration)
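+ # The negative LM prompt reuses the same chat template with the assistant metadata blanked out.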
+ out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True)
+ out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True)
+
+ out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
+ out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
+ out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
+ return out
+
+
+class Qwen3_06BModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B_ACE15, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["quantization_metadata"] = llama_quantization_metadata
+
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+class Qwen3_4B_ACE15(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["quantization_metadata"] = llama_quantization_metadata
+
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+class ACE15TEModel(torch.nn.Module):
+ def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}):
+ super().__init__()
+ if dtype_llama is None:
+ dtype_llama = dtype
+
+ model = None
+ self.constant = 0.4375
+ if lm_model == "qwen3_4b":
+ model = Qwen3_4B_ACE15
+ self.constant = 0.5625
+ elif lm_model == "qwen3_2b":
+ model = Qwen3_2B_ACE15
+
+ self.lm_model = lm_model
+ self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
+ if model is not None:
+ setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
+
+ self.dtypes = set([dtype, dtype_llama])
+
+ def encode_token_weights(self, token_weight_pairs):
+ token_weight_pairs_base = token_weight_pairs["qwen3_06b"]
+ token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
+
+ self.qwen3_06b.set_clip_options({"layer": None})
+ base_out, _, extra = self.qwen3_06b.encode_token_weights(token_weight_pairs_base)
+ self.qwen3_06b.set_clip_options({"layer": [0]})
+ lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
+
+ lm_metadata = token_weight_pairs["lm_metadata"]
+ audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
+
+ return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
+
+ def set_clip_options(self, options):
+ self.qwen3_06b.set_clip_options(options)
+ lm_model = getattr(self, self.lm_model, None)
+ if lm_model is not None:
+ lm_model.set_clip_options(options)
+
+ def reset_clip_options(self):
+ self.qwen3_06b.reset_clip_options()
+ lm_model = getattr(self, self.lm_model, None)
+ if lm_model is not None:
+ lm_model.reset_clip_options()
+
+ def load_sd(self, sd):
+ if "model.layers.0.post_attention_layernorm.weight" in sd:
+ shape = sd["model.layers.0.post_attention_layernorm.weight"].shape
+ if shape[0] == 1024:
+ return self.qwen3_06b.load_sd(sd)
+ else:
+ return getattr(self, self.lm_model).load_sd(sd)
+
+ def memory_estimation_function(self, token_weight_pairs, device=None):
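+ # Heuristic: memory scales with (prompt tokens + generated tokens) times a
+ # per-model constant in MB; bf16 halves the per-token cost.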
+ lm_metadata = token_weight_pairs["lm_metadata"]
+ constant = self.constant
+ if comfy.model_management.should_use_bf16(device):
+ constant *= 0.5
+
+ token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
+ num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+ num_tokens += lm_metadata['min_tokens']
+ return num_tokens * constant * 1024 * 1024
+
+def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
+ class ACE15TEModel_(ACE15TEModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
+ super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
+ return ACE15TEModel_
diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py
index 41f95bcb6..b6f58cb25 100644
--- a/comfy/text_encoders/anima.py
+++ b/comfy/text_encoders/anima.py
@@ -8,7 +8,7 @@ import torch
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
- super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py
index f67a5f805..1ae398789 100644
--- a/comfy/text_encoders/flux.py
+++ b/comfy/text_encoders/flux.py
@@ -118,7 +118,7 @@ class MistralTokenizerClass:
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.tekken_data = tokenizer_data.get("tekken_model", None)
- super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
+ super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
def state_dict(self):
return {"tekken_model": self.tekken_data}
@@ -176,12 +176,12 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
- super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
class Qwen3Tokenizer8B(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
- super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data)
class KleinTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"):
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index 3080a3e09..3afd094d1 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -1,11 +1,12 @@
import torch
import torch.nn as nn
from dataclasses import dataclass
-from typing import Optional, Any
+from typing import Optional, Any, Tuple
import math
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
+import comfy.ops
import comfy.ldm.common_dit
import comfy.clip_model
@@ -32,6 +33,7 @@ class Llama2Config:
k_norm = None
rope_scale = None
final_norm: bool = True
+ lm_head: bool = False
@dataclass
class Mistral3Small24BConfig:
@@ -54,6 +56,7 @@ class Mistral3Small24BConfig:
k_norm = None
rope_scale = None
final_norm: bool = True
+ lm_head: bool = False
@dataclass
class Qwen25_3BConfig:
@@ -76,6 +79,7 @@ class Qwen25_3BConfig:
k_norm = None
rope_scale = None
final_norm: bool = True
+ lm_head: bool = False
@dataclass
class Qwen3_06BConfig:
@@ -98,6 +102,76 @@ class Qwen3_06BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
+ lm_head: bool = False
+
+@dataclass
+class Qwen3_06B_ACE15_Config:
+ vocab_size: int = 151669
+ hidden_size: int = 1024
+ intermediate_size: int = 3072
+ num_hidden_layers: int = 28
+ num_attention_heads: int = 16
+ num_key_value_heads: int = 8
+ max_position_embeddings: int = 32768
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 1000000.0
+ transformer_type: str = "llama"
+ head_dim = 128
+ rms_norm_add = False
+ mlp_activation = "silu"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ rope_scale = None
+ final_norm: bool = True
+ lm_head: bool = False
+
+@dataclass
+class Qwen3_2B_ACE15_lm_Config:
+ vocab_size: int = 217204
+ hidden_size: int = 2048
+ intermediate_size: int = 6144
+ num_hidden_layers: int = 28
+ num_attention_heads: int = 16
+ num_key_value_heads: int = 8
+ max_position_embeddings: int = 40960
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 1000000.0
+ transformer_type: str = "llama"
+ head_dim = 128
+ rms_norm_add = False
+ mlp_activation = "silu"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ rope_scale = None
+ final_norm: bool = True
+ lm_head: bool = False
+
+@dataclass
+class Qwen3_4B_ACE15_lm_Config:
+ vocab_size: int = 217204
+ hidden_size: int = 2560
+ intermediate_size: int = 9728
+ num_hidden_layers: int = 36
+ num_attention_heads: int = 32
+ num_key_value_heads: int = 8
+ max_position_embeddings: int = 40960
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 1000000.0
+ transformer_type: str = "llama"
+ head_dim = 128
+ rms_norm_add = False
+ mlp_activation = "silu"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ rope_scale = None
+ final_norm: bool = True
+ lm_head: bool = False
@dataclass
class Qwen3_4BConfig:
@@ -120,6 +194,7 @@ class Qwen3_4BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
+ lm_head: bool = False
@dataclass
class Qwen3_8BConfig:
@@ -142,6 +217,7 @@ class Qwen3_8BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
+ lm_head: bool = False
@dataclass
class Ovis25_2BConfig:
@@ -164,6 +240,7 @@ class Ovis25_2BConfig:
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
+ lm_head: bool = False
@dataclass
class Qwen25_7BVLI_Config:
@@ -186,6 +263,7 @@ class Qwen25_7BVLI_Config:
k_norm = None
rope_scale = None
final_norm: bool = True
+ lm_head: bool = False
@dataclass
class Gemma2_2B_Config:
@@ -209,6 +287,7 @@ class Gemma2_2B_Config:
sliding_attention = None
rope_scale = None
final_norm: bool = True
+ lm_head: bool = False
@dataclass
class Gemma3_4B_Config:
@@ -232,6 +311,7 @@ class Gemma3_4B_Config:
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
+ lm_head: bool = False
@dataclass
class Gemma3_12B_Config:
@@ -255,6 +335,7 @@ class Gemma3_12B_Config:
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
+ lm_head: bool = False
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
mm_tokens_per_image = 256
@@ -356,6 +437,7 @@ class Attention(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states)
@@ -373,11 +455,30 @@ class Attention(nn.Module):
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
+ present_key_value = None
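+ # KV cache: write new keys/values into the pre-allocated buffer in place when it
+ # is large enough; otherwise fall back to concatenation and grow the cache.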
+ if past_key_value is not None:
+ index = 0
+ num_tokens = xk.shape[2]
+ if len(past_key_value) > 0:
+ past_key, past_value, index = past_key_value
+ if past_key.shape[2] >= (index + num_tokens):
+ past_key[:, :, index:index + xk.shape[2]] = xk
+ past_value[:, :, index:index + xv.shape[2]] = xv
+ xk = past_key[:, :, :index + xk.shape[2]]
+ xv = past_value[:, :, :index + xv.shape[2]]
+ present_key_value = (past_key, past_value, index + num_tokens)
+ else:
+ xk = torch.cat((past_key[:, :, :index], xk), dim=2)
+ xv = torch.cat((past_value[:, :, :index], xv), dim=2)
+ present_key_value = (xk, xv, index + num_tokens)
+ else:
+ present_key_value = (xk, xv, index + num_tokens)
+
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
- return self.o_proj(output)
+ return self.o_proj(output), present_key_value
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
@@ -408,15 +509,17 @@ class TransformerBlock(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
# Self Attention
residual = x
x = self.input_layernorm(x)
- x = self.self_attn(
+ x, present_key_value = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
+ past_key_value=past_key_value,
)
x = residual + x
@@ -426,7 +529,7 @@ class TransformerBlock(nn.Module):
x = self.mlp(x)
x = residual + x
- return x
+ return x, present_key_value
class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
@@ -451,6 +554,7 @@ class TransformerBlockGemma2(nn.Module):
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
+ past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
@@ -468,11 +572,12 @@ class TransformerBlockGemma2(nn.Module):
# Self Attention
residual = x
x = self.input_layernorm(x)
- x = self.self_attn(
+ x, present_key_value = self.self_attn(
hidden_states=x,
attention_mask=attention_mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
+ past_key_value=past_key_value,
)
x = self.post_attention_layernorm(x)
@@ -485,7 +590,7 @@ class TransformerBlockGemma2(nn.Module):
x = self.post_feedforward_layernorm(x)
x = residual + x
- return x
+ return x, present_key_value
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
@@ -516,9 +621,10 @@ class Llama2_(nn.Module):
else:
self.norm = None
- # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
+ if config.lm_head:
+ self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
- def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
+ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
if embeds is not None:
x = embeds
else:
@@ -527,8 +633,13 @@ class Llama2_(nn.Module):
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
+ seq_len = x.shape[1]
+ past_len = 0
+ if past_key_values is not None and len(past_key_values) > 0:
+ past_len = past_key_values[0][2]
+
if position_ids is None:
- position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
+ position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
@@ -539,14 +650,16 @@ class Llama2_(nn.Module):
mask = None
if attention_mask is not None:
- mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
- mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+ mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
+ mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min)
+
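+ # Single-token decode steps attend over the whole cache, so the causal mask is
+ # only built for the prefill pass (seq_len > 1).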
+ if seq_len > 1:
+ causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1)
+ if mask is not None:
+ mask += causal_mask
+ else:
+ mask = causal_mask
- causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
- if mask is not None:
- mask += causal_mask
- else:
- mask = causal_mask
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
@@ -562,16 +675,27 @@ class Llama2_(nn.Module):
elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
+ next_key_values = []
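+ # Collect each layer's updated (key, value, index) tuple so callers can pass the cache back in.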
for i, layer in enumerate(self.layers):
if all_intermediate is not None:
if only_layers is None or (i in only_layers):
all_intermediate.append(x.unsqueeze(1).clone())
- x = layer(
+
+ past_kv = None
+ if past_key_values is not None:
+ past_kv = past_key_values[i] if len(past_key_values) > 0 else []
+
+ x, current_kv = layer(
x=x,
attention_mask=mask,
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
+ past_key_value=past_kv,
)
+
+ if current_kv is not None:
+ next_key_values.append(current_kv)
+
if i == intermediate_output:
intermediate = x.clone()
@@ -588,7 +712,10 @@ class Llama2_(nn.Module):
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
intermediate = self.norm(intermediate)
- return x, intermediate
+ if len(next_key_values) > 0:
+ return x, intermediate, next_key_values
+ else:
+ return x, intermediate
class Gemma3MultiModalProjector(torch.nn.Module):
@@ -635,6 +762,21 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
+class BaseQwen3:
+ def logits(self, x):
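+ # Tied weights: project the last hidden state through the embedding matrix instead
+ # of a separate lm_head, honoring comfy's weight casting/offload path.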
+ input = x[:, -1:]
+ module = self.model.embed_tokens
+
+ offload_stream = None
+ if module.comfy_cast_weights:
+ weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
+ else:
+ weight = self.model.embed_tokens.weight.to(x)
+
+ x = torch.nn.functional.linear(input, weight, None)
+
+ comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
+ return x
class Llama2(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
@@ -663,7 +805,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Qwen3_06B(BaseLlama, torch.nn.Module):
+class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
@@ -672,7 +814,25 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Qwen3_4B(BaseLlama, torch.nn.Module):
+class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Qwen3_06B_ACE15_Config(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+
+class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Qwen3_2B_ACE15_lm_Config(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+
+class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
@@ -681,7 +841,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Qwen3_8B(BaseLlama, torch.nn.Module):
+class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Qwen3_4B_ACE15_lm_Config(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+
+class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_8BConfig(**config_dict)
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index e49161964..26573fb12 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -125,7 +125,7 @@ class LTXAVTEModel(torch.nn.Module):
for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]:
component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)}
if component_sd:
- missing, unexpected = component.load_state_dict(component_sd, strict=False)
+ missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False))
missing_all.extend([f"{prefix}{k}" for k in missing])
unexpected_all.extend([f"{prefix}{k}" for k in unexpected])
diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py
index ad41bfb1e..33b7cf594 100644
--- a/comfy/text_encoders/z_image.py
+++ b/comfy/text_encoders/z_image.py
@@ -6,7 +6,7 @@ import os
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
- super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class ZImageTokenizer(sd1_clip.SD1Tokenizer):
diff --git a/comfy/utils.py b/comfy/utils.py
index d97d753e6..1337e2205 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -28,9 +28,11 @@ import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
-from comfy.cli_args import args
+from comfy.cli_args import args, enables_dynamic_vram
import json
import time
+import mmap
+import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@@ -56,21 +58,74 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
else:
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
+# Current as of safetensors 0.7.0
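+# Maps safetensors dtype strings to torch dtypes for the mmap-based loader below.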
+_TYPES = {
+ "F64": torch.float64,
+ "F32": torch.float32,
+ "F16": torch.float16,
+ "BF16": torch.bfloat16,
+ "I64": torch.int64,
+ "I32": torch.int32,
+ "I16": torch.int16,
+ "I8": torch.int8,
+ "U8": torch.uint8,
+ "BOOL": torch.bool,
+ "F8_E4M3": torch.float8_e4m3fn,
+ "F8_E5M2": torch.float8_e5m2,
+ "C64": torch.complex64,
+
+ "U64": torch.uint64,
+ "U32": torch.uint32,
+ "U16": torch.uint16,
+}
+
+def load_safetensors(ckpt):
+ f = open(ckpt, "rb")
+ mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
+ mv = memoryview(mapping)
+
+ header_size = struct.unpack("<Q", mv[:8])[0]
if len(e.args) > 0:
message = e.args[0]
@@ -1308,3 +1363,16 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata
+
+def string_to_seed(data):
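+ # CRC-32 of the input bytes (IEEE polynomial, reflected form 0xEDB88320).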
+ crc = 0xFFFFFFFF
+ for byte in data:
+ if isinstance(byte, str):
+ byte = ord(byte)
+ crc ^= byte
+ for _ in range(8):
+ if crc & 1:
+ crc = (crc >> 1) ^ 0xEDB88320
+ else:
+ crc >>= 1
+ return crc ^ 0xFFFFFFFF
diff --git a/comfy/windows.py b/comfy/windows.py
new file mode 100644
index 000000000..213dc481d
--- /dev/null
+++ b/comfy/windows.py
@@ -0,0 +1,52 @@
+import ctypes
+import logging
+import psutil
+from ctypes import wintypes
+
+import comfy_aimdo.control
+
+psapi = ctypes.WinDLL("psapi")
+kernel32 = ctypes.WinDLL("kernel32")
+
+class PERFORMANCE_INFORMATION(ctypes.Structure):
+ _fields_ = [
+ ("cb", wintypes.DWORD),
+ ("CommitTotal", ctypes.c_size_t),
+ ("CommitLimit", ctypes.c_size_t),
+ ("CommitPeak", ctypes.c_size_t),
+ ("PhysicalTotal", ctypes.c_size_t),
+ ("PhysicalAvailable", ctypes.c_size_t),
+ ("SystemCache", ctypes.c_size_t),
+ ("KernelTotal", ctypes.c_size_t),
+ ("KernelPaged", ctypes.c_size_t),
+ ("KernelNonpaged", ctypes.c_size_t),
+ ("PageSize", ctypes.c_size_t),
+ ("HandleCount", wintypes.DWORD),
+ ("ProcessCount", wintypes.DWORD),
+ ("ThreadCount", wintypes.DWORD),
+ ]
+
+def get_free_ram():
+ #Windows is way too conservative and chalks recently used, uncommitted model RAM
+ #up as "in-use". So, calculate free RAM for general use as the greater of:
+ #
+ #1: What psutil says
+ #2: Total Memory - (Committed Memory - VRAM in use)
+ #
+ #We have to subtract VRAM in use from the committed memory, as WDDM creates a naked
+ #commit charge for all VRAM used just in case it wants to page it all out. That just
+ #isn't realistic, so "overcommit" in our calculations by subtracting it off.
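+ #e.g. 64GB physical, 70GB committed, 20GB VRAM in use:
+ #64 - (70 - 20) = 14GB treated as free, even if psutil reports less.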
+
+ pi = PERFORMANCE_INFORMATION()
+ pi.cb = ctypes.sizeof(pi)
+
+ if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
+ logging.warning("WARNING: Failed to query Windows performance info. RAM usage may be suboptimal")
+ return psutil.virtual_memory().available
+
+ committed = pi.CommitTotal * pi.PageSize
+ total = pi.PhysicalTotal * pi.PageSize
+
+ return max(psutil.virtual_memory().available,
+ total - (committed - comfy_aimdo.control.get_total_vram_usage()))
+
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index b0fa14ff6..8542a1dbc 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
-from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
+from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
from . import _io_public as io
from . import _ui_public as ui
from comfy_execution.utils import get_executing_context
@@ -105,6 +105,7 @@ class Types:
VideoComponents = VideoComponents
MESH = MESH
VOXEL = VOXEL
+ File3D = File3D
ComfyAPI = ComfyAPI_latest
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 78f77d4b2..93cf482ca 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
-from ._util import MESH, VOXEL, SVG as _SVG
+from ._util import MESH, VOXEL, SVG as _SVG, File3D
class FolderType(str, Enum):
@@ -667,6 +667,49 @@ class Voxel(ComfyTypeIO):
class Mesh(ComfyTypeIO):
Type = MESH
+
+@comfytype(io_type="FILE_3D")
+class File3DAny(ComfyTypeIO):
+ """General 3D file type - accepts any supported 3D format."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_GLB")
+class File3DGLB(ComfyTypeIO):
+ """GLB format 3D file - binary glTF, best for web and cross-platform."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_GLTF")
+class File3DGLTF(ComfyTypeIO):
+ """GLTF format 3D file - JSON-based glTF with external resources."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_FBX")
+class File3DFBX(ComfyTypeIO):
+ """FBX format 3D file - best for game engines and animation."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_OBJ")
+class File3DOBJ(ComfyTypeIO):
+ """OBJ format 3D file - simple geometry format."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_STL")
+class File3DSTL(ComfyTypeIO):
+ """STL format 3D file - best for 3D printing."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_USDZ")
+class File3DUSDZ(ComfyTypeIO):
+ """USDZ format 3D file - Apple AR format."""
+ Type = File3D
+
+
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):
if TYPE_CHECKING:
@@ -1248,6 +1291,7 @@ class Hidden(str, Enum):
class NodeInfoV1:
input: dict=None
input_order: dict[str, list[str]]=None
+ is_input_list: bool=None
output: list[str]=None
output_is_list: list[bool]=None
output_name: list[str]=None
@@ -1474,6 +1518,7 @@ class Schema:
info = NodeInfoV1(
input=input,
input_order={key: list(value.keys()) for (key, value) in input.items()},
+ is_input_list=self.is_input_list,
output=output,
output_is_list=output_is_list,
output_name=output_name,
@@ -2035,6 +2080,13 @@ __all__ = [
"LossMap",
"Voxel",
"Mesh",
+ "File3DAny",
+ "File3DGLB",
+ "File3DGLTF",
+ "File3DFBX",
+ "File3DOBJ",
+ "File3DSTL",
+ "File3DUSDZ",
"Hooks",
"HookKeyframes",
"TimestepsRange",
diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py
index 6313eb01b..115baf392 100644
--- a/comfy_api/latest/_util/__init__.py
+++ b/comfy_api/latest/_util/__init__.py
@@ -1,5 +1,5 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
-from .geometry_types import VOXEL, MESH
+from .geometry_types import VOXEL, MESH, File3D
from .image_types import SVG
__all__ = [
@@ -9,5 +9,6 @@ __all__ = [
"VideoComponents",
"VOXEL",
"MESH",
+ "File3D",
"SVG",
]
diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py
index 385122778..b586fceb3 100644
--- a/comfy_api/latest/_util/geometry_types.py
+++ b/comfy_api/latest/_util/geometry_types.py
@@ -1,3 +1,8 @@
+import shutil
+from io import BytesIO
+from pathlib import Path
+from typing import IO
+
import torch
@@ -10,3 +15,75 @@ class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
self.vertices = vertices
self.faces = faces
+
+
+class File3D:
+ """Class representing a 3D file from a file path or binary stream.
+
+ Supports both disk-backed (file path) and memory-backed (BytesIO) storage.
+ """
+
+ def __init__(self, source: str | IO[bytes], file_format: str = ""):
+ self._source = source
+ self._format = file_format or self._infer_format()
+
+ def _infer_format(self) -> str:
+ if isinstance(self._source, str):
+ return Path(self._source).suffix.lstrip(".").lower()
+ return ""
+
+ @property
+ def format(self) -> str:
+ return self._format
+
+ @format.setter
+ def format(self, value: str) -> None:
+ self._format = value.lstrip(".").lower() if value else ""
+
+ @property
+ def is_disk_backed(self) -> bool:
+ return isinstance(self._source, str)
+
+ def get_source(self) -> str | IO[bytes]:
+ if isinstance(self._source, str):
+ return self._source
+ if hasattr(self._source, "seek"):
+ self._source.seek(0)
+ return self._source
+
+ def get_data(self) -> BytesIO:
+ if isinstance(self._source, str):
+ with open(self._source, "rb") as f:
+ result = BytesIO(f.read())
+ return result
+ if hasattr(self._source, "seek"):
+ self._source.seek(0)
+ if isinstance(self._source, BytesIO):
+ return self._source
+ return BytesIO(self._source.read())
+
+ def save_to(self, path: str) -> str:
+ dest = Path(path)
+ dest.parent.mkdir(parents=True, exist_ok=True)
+
+ if isinstance(self._source, str):
+ if Path(self._source).resolve() != dest.resolve():
+ shutil.copy2(self._source, dest)
+ else:
+ if hasattr(self._source, "seek"):
+ self._source.seek(0)
+ with open(dest, "wb") as f:
+ f.write(self._source.read())
+ return str(dest)
+
+ def get_bytes(self) -> bytes:
+ if isinstance(self._source, str):
+ return Path(self._source).read_bytes()
+ if hasattr(self._source, "seek"):
+ self._source.seek(0)
+ return self._source.read()
+
+ def __repr__(self) -> str:
+ if isinstance(self._source, str):
+ return f"File3D(source={self._source!r}, format={self._format!r})"
+ return f"File3D(, format={self._format!r})"
diff --git a/comfy_api_nodes/apis/hitpaw.py b/comfy_api_nodes/apis/hitpaw.py
new file mode 100644
index 000000000..b23c5d9eb
--- /dev/null
+++ b/comfy_api_nodes/apis/hitpaw.py
@@ -0,0 +1,51 @@
+from typing import TypedDict
+
+from pydantic import BaseModel, Field
+
+
+class InputVideoModel(TypedDict):
+ model: str
+ resolution: str
+
+
+class ImageEnhanceTaskCreateRequest(BaseModel):
+ model_name: str = Field(...)
+ img_url: str = Field(...)
+ extension: str = Field(".png")
+ exif: bool = Field(False)
+ DPI: int | None = Field(None)
+
+
+class VideoEnhanceTaskCreateRequest(BaseModel):
+ video_url: str = Field(...)
+ extension: str = Field(".mp4")
+ model_name: str | None = Field(...)
+ resolution: list[int] = Field(..., description="Target resolution [width, height]")
+ original_resolution: list[int] = Field(..., description="Original video resolution [width, height]")
+
+
+class TaskCreateDataResponse(BaseModel):
+ job_id: str = Field(...)
+ consume_coins: int | None = Field(None)
+
+
+class TaskStatusPollRequest(BaseModel):
+ job_id: str = Field(...)
+
+
+class TaskCreateResponse(BaseModel):
+ code: int = Field(...)
+ message: str = Field(...)
+ data: TaskCreateDataResponse | None = Field(None)
+
+
+class TaskStatusDataResponse(BaseModel):
+ job_id: str = Field(...)
+ status: str = Field(...)
+ res_url: str = Field("")
+
+
+class TaskStatusResponse(BaseModel):
+ code: int = Field(...)
+ message: str = Field(...)
+ data: TaskStatusDataResponse = Field(...)
diff --git a/comfy_api_nodes/apis/meshy.py b/comfy_api_nodes/apis/meshy.py
index be46d0d58..7d72e6e91 100644
--- a/comfy_api_nodes/apis/meshy.py
+++ b/comfy_api_nodes/apis/meshy.py
@@ -109,14 +109,19 @@ class MeshyTextureRequest(BaseModel):
class MeshyModelsUrls(BaseModel):
glb: str = Field("")
+ fbx: str = Field("")
+ usdz: str = Field("")
+ obj: str = Field("")
class MeshyRiggedModelsUrls(BaseModel):
rigged_character_glb_url: str = Field("")
+ rigged_character_fbx_url: str = Field("")
class MeshyAnimatedModelsUrls(BaseModel):
animation_glb_url: str = Field("")
+ animation_fbx_url: str = Field("")
class MeshyResultTextureUrls(BaseModel):
diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py
new file mode 100644
index 000000000..488080a74
--- /dev/null
+++ b/comfy_api_nodes/nodes_hitpaw.py
@@ -0,0 +1,342 @@
+import math
+
+from typing_extensions import override
+
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.hitpaw import (
+ ImageEnhanceTaskCreateRequest,
+ InputVideoModel,
+ TaskCreateResponse,
+ TaskStatusPollRequest,
+ TaskStatusResponse,
+ VideoEnhanceTaskCreateRequest,
+)
+from comfy_api_nodes.util import (
+ ApiEndpoint,
+ download_url_to_image_tensor,
+ download_url_to_video_output,
+ downscale_image_tensor,
+ get_image_dimensions,
+ poll_op,
+ sync_op,
+ upload_image_to_comfyapi,
+ upload_video_to_comfyapi,
+ validate_video_duration,
+)
+
+VIDEO_MODELS_MAP = {
+ "Portrait Restore Model (1x)": "portrait_restore_1x",
+ "Portrait Restore Model (2x)": "portrait_restore_2x",
+ "General Restore Model (1x)": "general_restore_1x",
+ "General Restore Model (2x)": "general_restore_2x",
+ "General Restore Model (4x)": "general_restore_4x",
+ "Ultra HD Model (2x)": "ultrahd_restore_2x",
+ "Generative Model (1x)": "generative_1x",
+}
+
+# Resolution name to target dimension (shorter side) in pixels
+RESOLUTION_TARGET_MAP = {
+ "720p": 720,
+ "1080p": 1080,
+ "2K/QHD": 1440,
+ "4K/UHD": 2160,
+ "8K": 4320,
+}
+
+# Square (1:1) resolutions use standard square dimensions
+RESOLUTION_SQUARE_MAP = {
+ "720p": 720,
+ "1080p": 1080,
+ "2K/QHD": 1440,
+ "4K/UHD": 2048, # DCI 4K square
+ "8K": 4096, # DCI 8K square
+}
+
+# Models with limited resolution support (no 8K)
+LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"}
+
+# Resolution options for different model types
+RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"]
+RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"]
+
+# Maximum output size in total pixels
+MAX_PIXELS_GENERATIVE = 32_000_000
+MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000
+
+
+class HitPawGeneralImageEnhance(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="HitPawGeneralImageEnhance",
+ display_name="HitPaw General Image Enhance",
+ category="api node/image/HitPaw",
+ description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
+ f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
+ inputs=[
+ IO.Combo.Input("model", options=["generative_portrait", "generative"]),
+ IO.Image.Input("image"),
+ IO.Combo.Input("upscale_factor", options=[1, 2, 4]),
+ IO.Boolean.Input(
+ "auto_downscale",
+ default=False,
+ tooltip="Automatically downscale input image if output would exceed the limit.",
+ ),
+ ],
+ outputs=[
+ IO.Image.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model"]),
+ expr="""
+ (
+ $prices := {
+ "generative_portrait": {"min": 0.02, "max": 0.06},
+ "generative": {"min": 0.05, "max": 0.15}
+ };
+ $price := $lookup($prices, widgets.model);
+ {
+ "type": "range_usd",
+ "min_usd": $price.min,
+ "max_usd": $price.max
+ }
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model: str,
+ image: Input.Image,
+ upscale_factor: int,
+ auto_downscale: bool,
+ ) -> IO.NodeOutput:
+ height, width = get_image_dimensions(image)
+ requested_scale = upscale_factor
+ output_pixels = height * width * requested_scale * requested_scale
+ if output_pixels > MAX_PIXELS_GENERATIVE:
+ if auto_downscale:
+ input_pixels = width * height
+ scale = 1
+ max_input_pixels = MAX_PIXELS_GENERATIVE
+
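+ # Pick the largest scale factor (not above the requested one) whose output fits
+ # the pixel budget; tolerate up to a 2x input downscale to make a factor fit.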
+ for candidate in [4, 2, 1]:
+ if candidate > requested_scale:
+ continue
+ scale_output_pixels = input_pixels * candidate * candidate
+ if scale_output_pixels <= MAX_PIXELS_GENERATIVE:
+ scale = candidate
+ max_input_pixels = None
+ break
+ # Check if we can downscale input by at most 2x to fit
+ downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE)
+ if downscale_ratio <= 2.0:
+ scale = candidate
+ max_input_pixels = MAX_PIXELS_GENERATIVE // (candidate * candidate)
+ break
+
+ if max_input_pixels is not None:
+ image = downscale_image_tensor(image, total_pixels=max_input_pixels)
+ upscale_factor = scale
+ else:
+ output_width = width * requested_scale
+ output_height = height * requested_scale
+ raise ValueError(
+ f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) "
+ f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). "
+ f"Enable auto_downscale or use a smaller input image or a lower upscale factor."
+ )
+
+ initial_res = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"),
+ response_model=TaskCreateResponse,
+ data=ImageEnhanceTaskCreateRequest(
+ model_name=f"{model}_{upscale_factor}x",
+ img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None),
+ ),
+ wait_label="Creating task",
+ final_label_on_success="Task created",
+ )
+ if initial_res.code != 200:
+ raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
+ request_price = initial_res.data.consume_coins / 1000
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
+ data=TaskStatusPollRequest(job_id=initial_res.data.job_id),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda x: x.data.status,
+ price_extractor=lambda x: request_price,
+ poll_interval=10.0,
+ max_poll_attempts=480,
+ )
+ return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url))
+
+
+class HitPawVideoEnhance(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ model_options = []
+ for model_name in VIDEO_MODELS_MAP:
+ if model_name in LIMITED_RESOLUTION_MODELS:
+ resolutions = RESOLUTIONS_LIMITED
+ else:
+ resolutions = RESOLUTIONS_FULL
+ model_options.append(
+ IO.DynamicCombo.Option(
+ model_name,
+ [IO.Combo.Input("resolution", options=resolutions)],
+ )
+ )
+
+ return IO.Schema(
+ node_id="HitPawVideoEnhance",
+ display_name="HitPaw Video Enhance",
+ category="api node/video/HitPaw",
+ description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
+ "Prices shown are per second of video.",
+ inputs=[
+ IO.DynamicCombo.Input("model", options=model_options),
+ IO.Video.Input("video"),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
+ expr="""
+ (
+ $m := $lookup(widgets, "model");
+ $res := $lookup(widgets, "model.resolution");
+ $standard_model_prices := {
+ "original": {"min": 0.01, "max": 0.198},
+ "720p": {"min": 0.01, "max": 0.06},
+ "1080p": {"min": 0.015, "max": 0.09},
+ "2k/qhd": {"min": 0.02, "max": 0.117},
+ "4k/uhd": {"min": 0.025, "max": 0.152},
+ "8k": {"min": 0.033, "max": 0.198}
+ };
+ $ultra_hd_model_prices := {
+ "original": {"min": 0.015, "max": 0.264},
+ "720p": {"min": 0.015, "max": 0.092},
+ "1080p": {"min": 0.02, "max": 0.12},
+ "2k/qhd": {"min": 0.026, "max": 0.156},
+ "4k/uhd": {"min": 0.034, "max": 0.203},
+ "8k": {"min": 0.044, "max": 0.264}
+ };
+ $generative_model_prices := {
+ "original": {"min": 0.015, "max": 0.338},
+ "720p": {"min": 0.008, "max": 0.090},
+ "1080p": {"min": 0.05, "max": 0.15},
+ "2k/qhd": {"min": 0.038, "max": 0.225},
+ "4k/uhd": {"min": 0.056, "max": 0.338}
+ };
+ $prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices :
+ $contains($m, "generative") ? $generative_model_prices :
+ $standard_model_prices;
+ $price := $lookup($prices, $res);
+ {
+ "type": "range_usd",
+ "min_usd": $price.min,
+ "max_usd": $price.max,
+ "format": {"approximate": true, "suffix": "/second"}
+ }
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model: InputVideoModel,
+ video: Input.Video,
+ ) -> IO.NodeOutput:
+ validate_video_duration(video, min_duration=0.5, max_duration=60 * 60)
+ resolution = model["resolution"]
+ src_width, src_height = video.get_dimensions()
+
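+ # Derive the output resolution: square inputs map to the square presets, otherwise
+ # the preset sets the shorter side and the aspect ratio is preserved.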
+ if resolution == "original":
+ output_width = src_width
+ output_height = src_height
+ else:
+ if src_width == src_height:
+ target_size = RESOLUTION_SQUARE_MAP[resolution]
+ if target_size < src_width:
+ raise ValueError(
+ f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than "
+ f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'."
+ )
+ output_width = target_size
+ output_height = target_size
+ else:
+ min_dimension = min(src_width, src_height)
+ target_size = RESOLUTION_TARGET_MAP[resolution]
+ if target_size < min_dimension:
+ raise ValueError(
+ f"Selected resolution {resolution} ({target_size}p) is smaller than "
+ f"the input video's shorter dimension ({min_dimension}p). "
+ f"Please select a higher resolution or 'original'."
+ )
+ if src_width > src_height:
+ output_height = target_size
+ output_width = int(target_size * (src_width / src_height))
+ else:
+ output_width = target_size
+ output_height = int(target_size * (src_height / src_width))
+ initial_res = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"),
+ response_model=TaskCreateResponse,
+ data=VideoEnhanceTaskCreateRequest(
+ video_url=await upload_video_to_comfyapi(cls, video),
+ resolution=[output_width, output_height],
+ original_resolution=[src_width, src_height],
+ model_name=VIDEO_MODELS_MAP[model["model"]],
+ ),
+ wait_label="Creating task",
+ final_label_on_success="Task created",
+ )
+ if initial_res.code != 200:
+ raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
+ request_price = initial_res.data.consume_coins / 1000
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
+ data=TaskStatusPollRequest(job_id=initial_res.data.job_id),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda x: x.data.status,
+ price_extractor=lambda x: request_price,
+ poll_interval=10.0,
+ max_poll_attempts=320,
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url))
+
+
+class HitPawExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ HitPawGeneralImageEnhance,
+ HitPawVideoEnhance,
+ ]
+
+
+async def comfy_entrypoint() -> HitPawExtension:
+ return HitPawExtension()
diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py
index b3a736643..813a7c809 100644
--- a/comfy_api_nodes/nodes_hunyuan3d.py
+++ b/comfy_api_nodes/nodes_hunyuan3d.py
@@ -1,5 +1,3 @@
-import os
-
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
@@ -14,7 +12,7 @@ from comfy_api_nodes.apis.hunyuan3d import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
- download_url_to_bytesio,
+ download_url_to_file_3d,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
@@ -22,14 +20,13 @@ from comfy_api_nodes.util import (
validate_image_dimensions,
validate_string,
)
-from folder_paths import get_output_directory
-def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D:
+def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
for i in response_objs:
- if i.Type.lower() == "glb":
+ if i.Type.lower() == file_type.lower():
return i
- raise ValueError("No GLB file found in response. Please report this to the developers.")
+ return None
class TencentTextToModelNode(IO.ComfyNode):
@@ -74,7 +71,9 @@ class TencentTextToModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DOBJ.Output(display_name="OBJ"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -124,19 +123,20 @@ class TencentTextToModelNode(IO.ComfyNode):
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+ task_id = response.JobId
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
- data=To3DProTaskQueryRequest(JobId=response.JobId),
+ data=To3DProTaskQueryRequest(JobId=task_id),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
- model_file = f"hunyuan_model_{response.JobId}.glb"
- await download_url_to_bytesio(
- get_glb_obj_from_response(result.ResultFile3Ds).Url,
- os.path.join(get_output_directory(), model_file),
+ glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
+ obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
+ file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
+ return IO.NodeOutput(
+ file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
)
- return IO.NodeOutput(model_file)
class TencentImageToModelNode(IO.ComfyNode):
@@ -184,7 +184,9 @@ class TencentImageToModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DOBJ.Output(display_name="OBJ"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -269,19 +271,20 @@ class TencentImageToModelNode(IO.ComfyNode):
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+ task_id = response.JobId
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
- data=To3DProTaskQueryRequest(JobId=response.JobId),
+ data=To3DProTaskQueryRequest(JobId=task_id),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
- model_file = f"hunyuan_model_{response.JobId}.glb"
- await download_url_to_bytesio(
- get_glb_obj_from_response(result.ResultFile3Ds).Url,
- os.path.join(get_output_directory(), model_file),
+ glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
+ obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
+ file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
+ return IO.NodeOutput(
+ file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
)
- return IO.NodeOutput(model_file)
class TencentHunyuan3DExtension(ComfyExtension):
diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py
index 740607983..65f6f0d2d 100644
--- a/comfy_api_nodes/nodes_meshy.py
+++ b/comfy_api_nodes/nodes_meshy.py
@@ -1,5 +1,3 @@
-import os
-
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
@@ -20,13 +18,12 @@ from comfy_api_nodes.apis.meshy import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
- download_url_to_bytesio,
+ download_url_to_file_3d,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_string,
)
-from folder_paths import get_output_directory
class MeshyTextToModelNode(IO.ComfyNode):
@@ -79,8 +76,10 @@ class MeshyTextToModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -122,16 +121,20 @@ class MeshyTextToModelNode(IO.ComfyNode):
seed=seed,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"),
response_model=MeshyModelResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
+ )
class MeshyRefineNode(IO.ComfyNode):
@@ -167,8 +170,10 @@ class MeshyRefineNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -210,16 +215,20 @@ class MeshyRefineNode(IO.ComfyNode):
ai_model=model,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"),
response_model=MeshyModelResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
+ )
class MeshyImageToModelNode(IO.ComfyNode):
@@ -303,8 +312,10 @@ class MeshyImageToModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -368,16 +379,20 @@ class MeshyImageToModelNode(IO.ComfyNode):
seed=seed,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{task_id}"),
response_model=MeshyModelResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
+ )
class MeshyMultiImageToModelNode(IO.ComfyNode):
@@ -464,8 +479,10 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -531,16 +548,20 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
seed=seed,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{task_id}"),
response_model=MeshyModelResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
+ )
class MeshyRigModelNode(IO.ComfyNode):
@@ -571,8 +592,10 @@ class MeshyRigModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -606,18 +629,20 @@ class MeshyRigModelNode(IO.ComfyNode):
texture_image_url=texture_image_url,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{task_id}"),
response_model=MeshyRiggedResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(
- result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.result.rigged_character_glb_url, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.result.rigged_character_fbx_url, "fbx", task_id=task_id),
)
- return IO.NodeOutput(model_file, response.result)
class MeshyAnimateModelNode(IO.ComfyNode):
@@ -640,7 +665,9 @@ class MeshyAnimateModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -669,16 +696,19 @@ class MeshyAnimateModelNode(IO.ComfyNode):
action_id=action_id,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{task_id}"),
response_model=MeshyAnimationResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ await download_url_to_file_3d(result.result.animation_glb_url, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.result.animation_fbx_url, "fbx", task_id=task_id),
+ )
class MeshyTextureNode(IO.ComfyNode):
@@ -715,8 +745,10 @@ class MeshyTextureNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -760,16 +792,20 @@ class MeshyTextureNode(IO.ComfyNode):
image_style_url=image_style_url,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{task_id}"),
response_model=MeshyModelResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
+ )
class MeshyExtension(ComfyExtension):
diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py
index 3ffdc8b90..f9cff121f 100644
--- a/comfy_api_nodes/nodes_rodin.py
+++ b/comfy_api_nodes/nodes_rodin.py
@@ -10,7 +10,6 @@ import folder_paths as comfy_paths
import os
import logging
import math
-from typing import Optional
from io import BytesIO
from typing_extensions import override
from PIL import Image
@@ -28,8 +27,9 @@ from comfy_api_nodes.util import (
poll_op,
ApiEndpoint,
download_url_to_bytesio,
+ download_url_to_file_3d,
)
-from comfy_api.latest import ComfyExtension, IO
+from comfy_api.latest import ComfyExtension, IO, Types
COMMON_PARAMETERS = [
@@ -177,7 +177,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
return "DONE"
return "Generating"
-def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
+def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
if not response.jobs:
return None
completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
@@ -207,17 +207,25 @@ async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3D
)
-async def download_files(url_list, task_uuid: str):
+async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.File3D | None]:
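+ """Download all the files produced by a Rodin task into the output directory.
+
+ Returns (relative path to the .glb, in-memory File3D of the .glb);
+ both are None if the task produced no .glb file.
+ """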
result_folder_name = f"Rodin3D_{task_uuid}"
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
os.makedirs(save_path, exist_ok=True)
model_file_path = None
+ file_3d = None
+
for i in url_list.list:
file_path = os.path.join(save_path, i.name)
- if file_path.endswith(".glb"):
+ if i.name.lower().endswith(".glb"):
model_file_path = os.path.join(result_folder_name, i.name)
- await download_url_to_bytesio(i.url, file_path)
- return model_file_path
+ file_3d = await download_url_to_file_3d(i.url, "glb")
+ # Save to disk for backward compatibility
+ with open(file_path, "wb") as f:
+ f.write(file_3d.get_bytes())
+ else:
+ await download_url_to_bytesio(i.url, file_path)
+
+ return model_file_path, file_3d
class Rodin3D_Regular(IO.ComfyNode):
@@ -234,7 +242,10 @@ class Rodin3D_Regular(IO.ComfyNode):
IO.Image.Input("Images"),
*COMMON_PARAMETERS,
],
- outputs=[IO.String.Output(display_name="3D Model Path")],
+ outputs=[
+ IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ ],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
@@ -271,9 +282,9 @@ class Rodin3D_Regular(IO.ComfyNode):
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
- model = await download_files(download_list, task_uuid)
+ model_path, file_3d = await download_files(download_list, task_uuid)
- return IO.NodeOutput(model)
+ return IO.NodeOutput(model_path, file_3d)
class Rodin3D_Detail(IO.ComfyNode):
@@ -290,7 +301,10 @@ class Rodin3D_Detail(IO.ComfyNode):
IO.Image.Input("Images"),
*COMMON_PARAMETERS,
],
- outputs=[IO.String.Output(display_name="3D Model Path")],
+ outputs=[
+ IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ ],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
@@ -327,9 +341,9 @@ class Rodin3D_Detail(IO.ComfyNode):
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
- model = await download_files(download_list, task_uuid)
+ model_path, file_3d = await download_files(download_list, task_uuid)
- return IO.NodeOutput(model)
+ return IO.NodeOutput(model_path, file_3d)
class Rodin3D_Smooth(IO.ComfyNode):
@@ -346,7 +360,10 @@ class Rodin3D_Smooth(IO.ComfyNode):
IO.Image.Input("Images"),
*COMMON_PARAMETERS,
],
- outputs=[IO.String.Output(display_name="3D Model Path")],
+ outputs=[
+ IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ ],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
@@ -382,9 +399,9 @@ class Rodin3D_Smooth(IO.ComfyNode):
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
- model = await download_files(download_list, task_uuid)
+ model_path, file_3d = await download_files(download_list, task_uuid)
- return IO.NodeOutput(model)
+ return IO.NodeOutput(model_path, file_3d)
class Rodin3D_Sketch(IO.ComfyNode):
@@ -408,7 +425,10 @@ class Rodin3D_Sketch(IO.ComfyNode):
optional=True,
),
],
- outputs=[IO.String.Output(display_name="3D Model Path")],
+ outputs=[
+ IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ ],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
@@ -441,9 +461,9 @@ class Rodin3D_Sketch(IO.ComfyNode):
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
- model = await download_files(download_list, task_uuid)
+ model_path, file_3d = await download_files(download_list, task_uuid)
- return IO.NodeOutput(model)
+ return IO.NodeOutput(model_path, file_3d)
class Rodin3D_Gen2(IO.ComfyNode):
@@ -475,7 +495,10 @@ class Rodin3D_Gen2(IO.ComfyNode):
),
IO.Boolean.Input("TAPose", default=False),
],
- outputs=[IO.String.Output(display_name="3D Model Path")],
+ outputs=[
+ IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ ],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
@@ -511,9 +534,9 @@ class Rodin3D_Gen2(IO.ComfyNode):
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
- model = await download_files(download_list, task_uuid)
+ model_path, file_3d = await download_files(download_list, task_uuid)
- return IO.NodeOutput(model)
+ return IO.NodeOutput(model_path, file_3d)
class Rodin3DExtension(ComfyExtension):
diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py
index 5abf27b4d..67c7f59fc 100644
--- a/comfy_api_nodes/nodes_tripo.py
+++ b/comfy_api_nodes/nodes_tripo.py
@@ -1,10 +1,6 @@
-import os
-from typing import Optional
-
-import torch
from typing_extensions import override
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.tripo import (
TripoAnimateRetargetRequest,
TripoAnimateRigRequest,
@@ -26,12 +22,11 @@ from comfy_api_nodes.apis.tripo import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
- download_url_as_bytesio,
+ download_url_to_file_3d,
poll_op,
sync_op,
upload_images_to_comfyapi,
)
-from folder_paths import get_output_directory
def get_model_url_from_response(response: TripoTaskResponse) -> str:
@@ -45,7 +40,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
async def poll_until_finished(
node_cls: type[IO.ComfyNode],
response: TripoTaskResponse,
- average_duration: Optional[int] = None,
+ average_duration: int | None = None,
) -> IO.NodeOutput:
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
if response.code != 0:
@@ -69,12 +64,8 @@ async def poll_until_finished(
)
if response_poll.data.status == TripoTaskStatus.SUCCESS:
url = get_model_url_from_response(response_poll)
- bytesio = await download_url_as_bytesio(url)
- # Save the downloaded model file
- model_file = f"tripo_model_{task_id}.glb"
- with open(os.path.join(get_output_directory(), model_file), "wb") as f:
- f.write(bytesio.getvalue())
- return IO.NodeOutput(model_file, task_id)
+ file_glb = await download_url_to_file_3d(url, "glb", task_id=task_id)
+ return IO.NodeOutput(f"{task_id}.glb", task_id, file_glb)
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
@@ -107,8 +98,9 @@ class TripoTextToModelNode(IO.ComfyNode):
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -155,18 +147,18 @@ class TripoTextToModelNode(IO.ComfyNode):
async def execute(
cls,
prompt: str,
- negative_prompt: Optional[str] = None,
+ negative_prompt: str | None = None,
model_version=None,
- style: Optional[str] = None,
- texture: Optional[bool] = None,
- pbr: Optional[bool] = None,
- image_seed: Optional[int] = None,
- model_seed: Optional[int] = None,
- texture_seed: Optional[int] = None,
- texture_quality: Optional[str] = None,
- geometry_quality: Optional[str] = None,
- face_limit: Optional[int] = None,
- quad: Optional[bool] = None,
+ style: str | None = None,
+ texture: bool | None = None,
+ pbr: bool | None = None,
+ image_seed: int | None = None,
+ model_seed: int | None = None,
+ texture_seed: int | None = None,
+ texture_quality: str | None = None,
+ geometry_quality: str | None = None,
+ face_limit: int | None = None,
+ quad: bool | None = None,
) -> IO.NodeOutput:
style_enum = None if style == "None" else style
if not prompt:
@@ -232,8 +224,9 @@ class TripoImageToModelNode(IO.ComfyNode):
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -279,19 +272,19 @@ class TripoImageToModelNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- image: torch.Tensor,
- model_version: Optional[str] = None,
- style: Optional[str] = None,
- texture: Optional[bool] = None,
- pbr: Optional[bool] = None,
- model_seed: Optional[int] = None,
+ image: Input.Image,
+ model_version: str | None = None,
+ style: str | None = None,
+ texture: bool | None = None,
+ pbr: bool | None = None,
+ model_seed: int | None = None,
orientation=None,
- texture_seed: Optional[int] = None,
- texture_quality: Optional[str] = None,
- geometry_quality: Optional[str] = None,
- texture_alignment: Optional[str] = None,
- face_limit: Optional[int] = None,
- quad: Optional[bool] = None,
+ texture_seed: int | None = None,
+ texture_quality: str | None = None,
+ geometry_quality: str | None = None,
+ texture_alignment: str | None = None,
+ face_limit: int | None = None,
+ quad: bool | None = None,
) -> IO.NodeOutput:
style_enum = None if style == "None" else style
if image is None:
@@ -368,8 +361,9 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -411,21 +405,21 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- image: torch.Tensor,
- image_left: Optional[torch.Tensor] = None,
- image_back: Optional[torch.Tensor] = None,
- image_right: Optional[torch.Tensor] = None,
- model_version: Optional[str] = None,
- orientation: Optional[str] = None,
- texture: Optional[bool] = None,
- pbr: Optional[bool] = None,
- model_seed: Optional[int] = None,
- texture_seed: Optional[int] = None,
- texture_quality: Optional[str] = None,
- geometry_quality: Optional[str] = None,
- texture_alignment: Optional[str] = None,
- face_limit: Optional[int] = None,
- quad: Optional[bool] = None,
+ image: Input.Image,
+ image_left: Input.Image | None = None,
+ image_back: Input.Image | None = None,
+ image_right: Input.Image | None = None,
+ model_version: str | None = None,
+ orientation: str | None = None,
+ texture: bool | None = None,
+ pbr: bool | None = None,
+ model_seed: int | None = None,
+ texture_seed: int | None = None,
+ texture_quality: str | None = None,
+ geometry_quality: str | None = None,
+ texture_alignment: str | None = None,
+ face_limit: int | None = None,
+ quad: bool | None = None,
) -> IO.NodeOutput:
if image is None:
raise RuntimeError("front image for multiview is required")
@@ -487,8 +481,9 @@ class TripoTextureNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -512,11 +507,11 @@ class TripoTextureNode(IO.ComfyNode):
async def execute(
cls,
model_task_id,
- texture: Optional[bool] = None,
- pbr: Optional[bool] = None,
- texture_seed: Optional[int] = None,
- texture_quality: Optional[str] = None,
- texture_alignment: Optional[str] = None,
+ texture: bool | None = None,
+ pbr: bool | None = None,
+ texture_seed: int | None = None,
+ texture_quality: str | None = None,
+ texture_alignment: str | None = None,
) -> IO.NodeOutput:
response = await sync_op(
cls,
@@ -547,8 +542,9 @@ class TripoRefineNode(IO.ComfyNode):
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -583,8 +579,9 @@ class TripoRigNode(IO.ComfyNode):
category="api node/3d/Tripo",
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -642,8 +639,9 @@ class TripoRetargetNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
index c3c9ff4bf..18b020eef 100644
--- a/comfy_api_nodes/util/__init__.py
+++ b/comfy_api_nodes/util/__init__.py
@@ -28,6 +28,7 @@ from .conversions import (
from .download_helpers import (
download_url_as_bytesio,
download_url_to_bytesio,
+ download_url_to_file_3d,
download_url_to_image_tensor,
download_url_to_video_output,
)
@@ -69,6 +70,7 @@ __all__ = [
# Download helpers
"download_url_as_bytesio",
"download_url_to_bytesio",
+ "download_url_to_file_3d",
"download_url_to_image_tensor",
"download_url_to_video_output",
# Conversions
diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py
index 4668d14a9..78bcf1fa1 100644
--- a/comfy_api_nodes/util/download_helpers.py
+++ b/comfy_api_nodes/util/download_helpers.py
@@ -11,7 +11,8 @@ import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError
from comfy_api.latest import IO as COMFY_IO
-from comfy_api.latest import InputImpl
+from comfy_api.latest import InputImpl, Types
+from folder_paths import get_output_directory
from . import request_logger
from ._helpers import (
@@ -261,3 +262,38 @@ def _generate_operation_id(method: str, url: str, attempt: int) -> str:
except Exception:
slug = "download"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
+
+
+async def download_url_to_file_3d(
+ url: str,
+ file_format: str,
+ *,
+ task_id: str | None = None,
+ timeout: float | None = None,
+ max_retries: int = 5,
+ cls: type[COMFY_IO.ComfyNode] | None = None,
+) -> Types.File3D:
+ """Downloads a 3D model file from a URL into memory as BytesIO.
+
+ If task_id is provided, also writes the file to disk in the output directory
+ for backward compatibility with the old save-to-disk behavior.
+ """
+ file_format = file_format.lstrip(".").lower()
+ data = BytesIO()
+ await download_url_to_bytesio(
+ url,
+ data,
+ timeout=timeout,
+ max_retries=max_retries,
+ cls=cls,
+ )
+
+ if task_id is not None:
+ # This is only for backward compatibility with the current behavior, where every 3D node is an output node.
+ # New API nodes should not pass "task_id"; users should save results with the "SaveGLB" node instead.
+ output_dir = Path(get_output_directory())
+ output_path = output_dir / f"{task_id}.{file_format}"
+ output_path.write_bytes(data.getvalue())
+ data.seek(0)
+
+ return Types.File3D(source=data, file_format=file_format)
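+
+# Illustrative usage (a sketch; url and task_id are placeholders):
+#   file_glb = await download_url_to_file_3d(url, "glb", task_id=task_id)  # also writes <output>/<task_id>.glb
+#   file_obj = await download_url_to_file_3d(url, ".OBJ")  # extension normalized to "obj"; kept in memory only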
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index 3153f2b98..83d936ce1 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -94,7 +94,7 @@ async def upload_image_to_comfyapi(
*,
mime_type: str | None = None,
wait_label: str | None = "Uploading",
- total_pixels: int = 2048 * 2048,
+ total_pixels: int | None = 2048 * 2048,
) -> str:
"""Uploads a single image to ComfyUI API and returns its download URL."""
return (
diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py
index 1409233c9..376584e5c 100644
--- a/comfy_extras/nodes_ace.py
+++ b/comfy_extras/nodes_ace.py
@@ -28,12 +28,39 @@ class TextEncodeAceStepAudio(io.ComfyNode):
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
return io.NodeOutput(conditioning)
+class TextEncodeAceStepAudio15(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="TextEncodeAceStepAudio1.5",
+ category="conditioning",
+ inputs=[
+ io.Clip.Input("clip"),
+ io.String.Input("tags", multiline=True, dynamic_prompts=True),
+ io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
+ io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
+ io.Int.Input("bpm", default=120, min=10, max=300),
+ io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
+ io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
+ io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
+ io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
+ ],
+ outputs=[io.Conditioning.Output()],
+ )
+
+ @classmethod
+ def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
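+ # Musical metadata (bpm, duration, time signature, language, key scale) and the seed are passed to the tokenizer alongside tags and lyrics.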
+ tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
+ conditioning = clip.encode_from_tokens_scheduled(tokens)
+ return io.NodeOutput(conditioning)
+
class EmptyAceStepLatentAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyAceStepLatentAudio",
+ display_name="Empty Ace Step 1.0 Latent Audio",
category="latent/audio",
inputs=[
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
@@ -51,12 +78,60 @@ class EmptyAceStepLatentAudio(io.ComfyNode):
return io.NodeOutput({"samples": latent, "type": "audio"})
+class EmptyAceStep15LatentAudio(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="EmptyAceStep1.5LatentAudio",
+ display_name="Empty Ace Step 1.5 Latent Audio",
+ category="latent/audio",
+ inputs=[
+ io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
+ io.Int.Input(
+ "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
+ ),
+ ],
+ outputs=[io.Latent.Output()],
+ )
+
+ @classmethod
+ def execute(cls, seconds, batch_size) -> io.NodeOutput:
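+ # 48000 samples per second / 1920 samples per latent frame = 25 latent frames per second (seconds=120.0 -> length 3000).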
+ length = round(seconds * 48000 / 1920)
+ latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
+ return io.NodeOutput({"samples": latent, "type": "audio"})
+
+class ReferenceTimbreAudio(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="ReferenceTimbreAudio",
+ category="advanced/conditioning/audio",
+ is_experimental=True,
+ description="This node sets the reference audio for timbre (for ace step 1.5)",
+ inputs=[
+ io.Conditioning.Input("conditioning"),
+ io.Latent.Input("latent", optional=True),
+ ],
+ outputs=[
+ io.Conditioning.Output(),
+ ]
+ )
+
+ @classmethod
+ def execute(cls, conditioning, latent=None) -> io.NodeOutput:
+ if latent is not None:
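+ # append=True adds to any existing reference_audio_timbre_latents rather than overwriting them.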
+ conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
+ return io.NodeOutput(conditioning)
+
class AceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeAceStepAudio,
EmptyAceStepLatentAudio,
+ TextEncodeAceStepAudio15,
+ EmptyAceStep15LatentAudio,
+ ReferenceTimbreAudio,
]
async def comfy_entrypoint() -> AceExtension:
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 271b75fbd..bef723dce 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -82,13 +82,14 @@ class VAEEncodeAudio(IO.ComfyNode):
@classmethod
def execute(cls, vae, audio) -> IO.NodeOutput:
sample_rate = audio["sample_rate"]
- if 44100 != sample_rate:
- waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
+ vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
+ if vae_sample_rate != sample_rate:
+ waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate)
else:
waveform = audio["waveform"]
t = vae.encode(waveform.movedim(1, -1))
- return IO.NodeOutput({"samples":t})
+ return IO.NodeOutput({"samples": t})
encode = execute # TODO: remove
@@ -114,7 +115,8 @@ class VAEDecodeAudio(IO.ComfyNode):
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
- return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]})
+ vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
+ return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]})
decode = execute # TODO: remove
diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py
index 5bb5df48e..eda1639ab 100644
--- a/comfy_extras/nodes_hunyuan3d.py
+++ b/comfy_extras/nodes_hunyuan3d.py
@@ -622,14 +622,20 @@ class SaveGLB(IO.ComfyNode):
category="3d",
is_output_node=True,
inputs=[
- IO.Mesh.Input("mesh"),
+ IO.MultiType.Input(
+ IO.Mesh.Input("mesh"),
+ types=[
+ IO.File3DGLB,
+ ],
+ tooltip="Mesh or GLB file to save",
+ ),
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
)
@classmethod
- def execute(cls, mesh, filename_prefix) -> IO.NodeOutput:
+ def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = []
@@ -641,15 +647,26 @@ class SaveGLB(IO.ComfyNode):
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
- for i in range(mesh.vertices.shape[0]):
+ if isinstance(mesh, Types.File3D):
+ # Handle File3D input - save BytesIO data to output folder
f = f"{filename}_{counter:05}_.glb"
- save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
+ mesh.save_to(os.path.join(full_output_folder, f))
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
- counter += 1
+ else:
+ # Handle Mesh input - save vertices and faces as GLB
+ for i in range(mesh.vertices.shape[0]):
+ f = f"{filename}_{counter:05}_.glb"
+ save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
+ results.append({
+ "filename": f,
+ "subfolder": subfolder,
+ "type": "output"
+ })
+ counter += 1
return IO.NodeOutput(ui={"3d": results})
diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py
index 4b8d950ae..f29510488 100644
--- a/comfy_extras/nodes_load_3d.py
+++ b/comfy_extras/nodes_load_3d.py
@@ -1,9 +1,10 @@
import nodes
import folder_paths
import os
+import uuid
from typing_extensions import override
-from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
+from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types
from pathlib import Path
@@ -81,7 +82,19 @@ class Preview3D(IO.ComfyNode):
is_experimental=True,
is_output_node=True,
inputs=[
- IO.String.Input("model_file", default="", multiline=False),
+ IO.MultiType.Input(
+ IO.String.Input("model_file", default="", multiline=False),
+ types=[
+ IO.File3DGLB,
+ IO.File3DGLTF,
+ IO.File3DFBX,
+ IO.File3DOBJ,
+ IO.File3DSTL,
+ IO.File3DUSDZ,
+ IO.File3DAny,
+ ],
+ tooltip="3D model file or path string",
+ ),
IO.Load3DCamera.Input("camera_info", optional=True),
IO.Image.Input("bg_image", optional=True),
],
@@ -89,10 +102,15 @@ class Preview3D(IO.ComfyNode):
)
@classmethod
- def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
+ def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput:
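+ # An in-memory File3D is written to a uniquely named file in the output directory so the preview can load it by path; a plain string is treated as an existing file path.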
+ if isinstance(model_file, Types.File3D):
+ filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
+ model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
+ else:
+ filename = model_file
camera_info = kwargs.get("camera_info", None)
bg_image = kwargs.get("bg_image", None)
- return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
+ return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image))
process = execute # TODO: remove
diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py
index 82c4754a3..176e6bc2f 100644
--- a/comfy_extras/nodes_model_patch.py
+++ b/comfy_extras/nodes_model_patch.py
@@ -267,9 +267,9 @@ class ModelPatchLoader:
device=comfy.model_management.unet_offload_device(),
operations=comfy.ops.manual_cast)
- model.load_state_dict(sd)
- model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
- return (model,)
+ model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
+ model.load_state_dict(sd, assign=model_patcher.is_dynamic())
+ return (model_patcher,)
class DiffSynthCnetPatch:
diff --git a/comfyui_version.py b/comfyui_version.py
index b1ebaa115..5d296cd1b 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.11.1"
+__version__ = "0.12.2"
diff --git a/cuda_malloc.py b/cuda_malloc.py
index ee2bc4b69..b2182df37 100644
--- a/cuda_malloc.py
+++ b/cuda_malloc.py
@@ -1,8 +1,10 @@
import os
import importlib.util
-from comfy.cli_args import args, PerformanceFeature
+from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import subprocess
+import comfy_aimdo.control
+
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
if os.name == 'nt':
@@ -85,8 +87,14 @@ if not args.cuda_malloc:
except:
pass
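+# DynamicVRAM supplies its own allocator (installed in main.py), so the cudaMallocAsync backend is disabled and any PYTORCH_CUDA_ALLOC_CONF is cleared.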
+if enables_dynamic_vram() and comfy_aimdo.control.init():
+ args.cuda_malloc = False
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ""
-if args.cuda_malloc and not args.disable_cuda_malloc:
+if args.disable_cuda_malloc:
+ args.cuda_malloc = False
+
+if args.cuda_malloc:
env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None)
if env_var is None:
env_var = "backend:cudaMallocAsync"
diff --git a/execution.py b/execution.py
index 4b4f63c80..3dbab82e6 100644
--- a/execution.py
+++ b/execution.py
@@ -9,9 +9,11 @@ import traceback
from enum import Enum
from typing import List, Literal, NamedTuple, Optional, Union
import asyncio
+from contextlib import nullcontext
import torch
+import comfy.memory_management
import comfy.model_management
from latent_preview import set_preview_method
import nodes
@@ -515,7 +517,19 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
- output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
+
+ # Do comfy_aimdo mempool chunking here at the per-node level. Multi-model workflows
+ # produce all sorts of incompatible memory shapes that fragment the PyTorch allocator,
+ # so we cull them out after each model run.
+ allocator = comfy.memory_management.aimdo_allocator
+ with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
+ try:
+ output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
+ finally:
+ if allocator is not None:
+ comfy.model_management.reset_cast_buffers()
+ torch.cuda.synchronize()
+
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
@@ -1000,22 +1014,34 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[
outputs = set()
for x in prompt:
if 'class_type' not in prompt[x]:
+ node_data = prompt[x]
+ node_title = node_data.get('_meta', {}).get('title')
error = {
- "type": "invalid_prompt",
- "message": "Cannot execute because a node is missing the class_type property.",
+ "type": "missing_node_type",
+ "message": f"Node '{node_title or f'ID #{x}'}' has no class_type. The workflow may be corrupted or a custom node is missing.",
"details": f"Node ID '#{x}'",
- "extra_info": {}
+ "extra_info": {
+ "node_id": x,
+ "class_type": None,
+ "node_title": node_title
+ }
}
return (False, error, [], {})
class_type = prompt[x]['class_type']
class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None)
if class_ is None:
+ node_data = prompt[x]
+ node_title = node_data.get('_meta', {}).get('title', class_type)
error = {
- "type": "invalid_prompt",
- "message": f"Cannot execute because node {class_type} does not exist.",
+ "type": "missing_node_type",
+ "message": f"Node '{node_title}' not found. The custom node may not be installed.",
"details": f"Node ID '#{x}'",
- "extra_info": {}
+ "extra_info": {
+ "node_id": x,
+ "class_type": class_type,
+ "node_title": node_title
+ }
}
return (False, error, [], {})
diff --git a/main.py b/main.py
index 37b06c1fa..92d705b4d 100644
--- a/main.py
+++ b/main.py
@@ -5,7 +5,7 @@ import os
import importlib.util
import folder_paths
import time
-from comfy.cli_args import args
+from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger
from app.assets.scanner import seed_assets
import itertools
@@ -173,6 +173,7 @@ import gc
if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
+
import comfy.utils
import execution
@@ -184,6 +185,36 @@ import comfyui_version
import app.logger
import hook_breaker_ac10a0
+import comfy.memory_management
+import comfy.model_patcher
+
+import comfy_aimdo.control
+import comfy_aimdo.torch
+
+if enables_dynamic_vram():
+ if comfy.model_management.torch_version_numeric < (2, 8):
+ logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
+ comfy.memory_management.aimdo_allocator = None
+ elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
+ if args.verbose == 'DEBUG':
+ comfy_aimdo.control.set_log_debug()
+ elif args.verbose == 'CRITICAL':
+ comfy_aimdo.control.set_log_critical()
+ elif args.verbose == 'ERROR':
+ comfy_aimdo.control.set_log_error()
+ elif args.verbose == 'WARNING':
+ comfy_aimdo.control.set_log_warning()
+ else: #INFO
+ comfy_aimdo.control.set_log_info()
+
+ comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
+ comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator()
+ logging.info("DynamicVRAM support detected and enabled")
+ else:
+ logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
+ comfy.memory_management.aimdo_allocator = None
+
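+# DynamicVRAM is opted in with "--fast dynamic_vram"; enables_dynamic_vram() additionally requires that --highvram and --gpu-only are not set.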
+
def cuda_malloc_warning():
device = comfy.model_management.get_torch_device()
device_name = comfy.model_management.get_torch_device_name(device)
diff --git a/nodes.py b/nodes.py
index 1cb43d9e2..e11a8ed80 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1001,7 +1001,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
- "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ),
+ "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
diff --git a/pyproject.toml b/pyproject.toml
index 042f124e4..1ddcc3596 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.11.1"
+version = "0.12.2"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
diff --git a/requirements.txt b/requirements.txt
index 4ac94cb16..0c401873a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.37.11
-comfyui-workflow-templates==0.8.27
+comfyui-workflow-templates==0.8.31
comfyui-embedded-docs==0.4.0
torch
torchsde
@@ -22,6 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
+comfy-aimdo>=0.1.7
requests
#non essential dependencies:
diff --git a/server.py b/server.py
index 2aee5cc06..2300393b2 100644
--- a/server.py
+++ b/server.py
@@ -656,6 +656,7 @@ class PromptServer():
info = {}
info['input'] = obj_class.INPUT_TYPES()
info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()}
+ info['is_input_list'] = getattr(obj_class, "INPUT_IS_LIST", False)
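+ # INPUT_IS_LIST nodes receive each input as a whole list in a single execution call.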
info['output'] = obj_class.RETURN_TYPES
info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES)
info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output']