diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 4b3a3798c..f59999af6 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -755,6 +755,10 @@ class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
+class ACEAudio15(LatentFormat):
+    # ACE-Step v1.5 audio latents: 64 channels over a single (time) dimension.
+    latent_channels = 64
+    latent_dimensions = 1
+
class ChromaRadiance(LatentFormat):
latent_channels = 3
spacial_downscale_ratio = 1
diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py
new file mode 100644
index 000000000..d90549658
--- /dev/null
+++ b/comfy/ldm/ace/ace_step15.py
@@ -0,0 +1,1093 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import itertools
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.model_management
+from comfy.ldm.flux.layers import timestep_embedding
+
+def get_layer_class(operations, layer_name):
+ if operations is not None and hasattr(operations, layer_name):
+ return getattr(operations, layer_name)
+ return getattr(nn, layer_name)
+
+class RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=32768, base=1000000.0, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.dim = dim
+ self.base = base
+ self.max_position_embeddings = max_position_embeddings
+
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self._set_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.get_default_dtype() if dtype is None else dtype)
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
+ freqs = torch.outer(t, self.inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, seq_len=None):
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len, x.device, x.dtype)
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device),
+ self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device),
+ )
+
+def rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+ cos = cos.unsqueeze(0).unsqueeze(0)
+ sin = sin.unsqueeze(0).unsqueeze(0)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+class MLP(nn.Module):
+ def __init__(self, hidden_size, intermediate_size, dtype=None, device=None, operations=None):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.gate_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
+ self.up_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
+ self.down_proj = Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device)
+ self.act_fn = nn.SiLU()
+
+ def forward(self, x):
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, in_channels: int, time_embed_dim: int, scale: float = 1000, dtype=None, device=None, operations=None):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, dtype=dtype, device=device)
+ self.act1 = nn.SiLU()
+ self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, dtype=dtype, device=device)
+ self.in_channels = in_channels
+ self.act2 = nn.SiLU()
+ self.time_proj = Linear(time_embed_dim, time_embed_dim * 6, dtype=dtype, device=device)
+ self.scale = scale
+
+ def forward(self, t, dtype=None):
+ t_freq = timestep_embedding(t, self.in_channels, time_factor=self.scale)
+ temb = self.linear_1(t_freq.to(dtype=dtype))
+ temb = self.act1(temb)
+ temb = self.linear_2(temb)
+ timestep_proj = self.time_proj(self.act2(temb)).view(-1, 6, temb.shape[-1])
+ return temb, timestep_proj
+
+class AceStepAttention(nn.Module):
+    """Multi-head attention with per-head RMSNorm on Q/K, grouped-query K/V
+    heads, optional RoPE, and an optional bidirectional sliding-window bias.
+    With is_cross_attention=True, keys/values come from encoder_hidden_states."""
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        num_kv_heads,
+        head_dim,
+        rms_norm_eps=1e-6,
+        is_cross_attention=False,
+        sliding_window=None,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = head_dim
+        self.is_cross_attention = is_cross_attention
+        self.sliding_window = sliding_window
+
+        Linear = get_layer_class(operations, "Linear")
+
+        self.q_proj = Linear(hidden_size, num_heads * head_dim, bias=False, dtype=dtype, device=device)
+        self.k_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device)
+        self.v_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device)
+        self.o_proj = Linear(num_heads * head_dim, hidden_size, bias=False, dtype=dtype, device=device)
+
+        # Per-head-dim normalization applied to Q and K before attention.
+        self.q_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device)
+        self.k_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device)
+
+    def forward(
+        self,
+        hidden_states,
+        encoder_hidden_states=None,
+        attention_mask=None,
+        position_embeddings=None,
+    ):
+        # NOTE(review): `attention_mask` is accepted but never applied below —
+        # confirm padding masks are intentionally ignored in this path.
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+        query_states = self.q_norm(query_states)
+        query_states = query_states.transpose(1, 2)
+
+        if self.is_cross_attention and encoder_hidden_states is not None:
+            # Cross-attention: keys/values projected from the encoder sequence.
+            bsz_enc, kv_len, _ = encoder_hidden_states.size()
+            key_states = self.k_proj(encoder_hidden_states)
+            value_states = self.v_proj(encoder_hidden_states)
+
+            key_states = key_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim)
+            key_states = self.k_norm(key_states)
+            value_states = value_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim)
+
+            key_states = key_states.transpose(1, 2)
+            value_states = value_states.transpose(1, 2)
+        else:
+            # Self-attention: keys/values projected from the same sequence.
+            kv_len = q_len
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
+            key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
+            key_states = self.k_norm(key_states)
+            value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim)
+
+            key_states = key_states.transpose(1, 2)
+            value_states = value_states.transpose(1, 2)
+
+        if position_embeddings is not None:
+            # Apply RoPE to Q and K (the cross-attention caller omits this arg).
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        # Grouped-query attention: replicate K/V heads to match query heads.
+        n_rep = self.num_heads // self.num_kv_heads
+        if n_rep > 1:
+            key_states = key_states.repeat_interleave(n_rep, dim=1)
+            value_states = value_states.repeat_interleave(n_rep, dim=1)
+
+        attn_bias = None
+        if self.sliding_window is not None and not self.is_cross_attention:
+            # Bidirectional local window: token i may attend to j iff |i-j| <= window.
+            indices = torch.arange(q_len, device=query_states.device)
+            diff = indices.unsqueeze(1) - indices.unsqueeze(0)
+            in_window = torch.abs(diff) <= self.sliding_window
+
+            window_bias = torch.zeros((q_len, kv_len), device=query_states.device, dtype=query_states.dtype)
+            min_value = torch.finfo(query_states.dtype).min
+            window_bias.masked_fill_(~in_window, min_value)
+
+            window_bias = window_bias.unsqueeze(0).unsqueeze(0)
+
+            # NOTE(review): attn_bias is always None at this point (set just
+            # above), so this merge branch is currently unreachable dead code.
+            if attn_bias is not None:
+                if attn_bias.dtype == torch.bool:
+                    base_bias = torch.zeros_like(window_bias)
+                    base_bias.masked_fill_(~attn_bias, min_value)
+                    attn_bias = base_bias + window_bias
+                else:
+                    attn_bias = attn_bias + window_bias
+            else:
+                attn_bias = window_bias
+
+        attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output
+
+class AceStepDiTLayer(nn.Module):
+    """DiT transformer block: AdaLN-modulated self-attention, unmodulated
+    cross-attention over the condition sequence, and an AdaLN-modulated gated
+    MLP — each wrapped in a residual connection."""
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        num_kv_heads,
+        head_dim,
+        intermediate_size,
+        rms_norm_eps=1e-6,
+        layer_type="full_attention",
+        sliding_window=128,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+
+        # Only "sliding_attention" layers restrict self-attention to a local window.
+        self_attn_window = sliding_window if layer_type == "sliding_attention" else None
+
+        self.self_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+        self.self_attn = AceStepAttention(
+            hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
+            is_cross_attention=False, sliding_window=self_attn_window,
+            dtype=dtype, device=device, operations=operations
+        )
+
+        self.cross_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+        self.cross_attn = AceStepAttention(
+            hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
+            is_cross_attention=True, dtype=dtype, device=device, operations=operations
+        )
+
+        self.mlp_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+        self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations)
+
+        # Learned base added to the per-timestep 6-way modulation projection.
+        self.scale_shift_table = nn.Parameter(torch.empty(1, 6, hidden_size, dtype=dtype, device=device))
+
+    def forward(
+        self,
+        hidden_states,
+        temb,
+        encoder_hidden_states,
+        position_embeddings,
+        attention_mask=None,
+        encoder_attention_mask=None
+    ):
+        # temb: (B, 6, hidden) -> shift/scale/gate for self-attn, then the MLP.
+        modulation = comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = modulation.chunk(6, dim=1)
+
+        # Self-attention with AdaLN modulation and gated residual.
+        norm_hidden = self.self_attn_norm(hidden_states)
+        norm_hidden = norm_hidden * (1 + scale_msa) + shift_msa
+
+        attn_out = self.self_attn(
+            norm_hidden,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask
+        )
+        hidden_states = hidden_states + attn_out * gate_msa
+
+        # Cross-attention over the condition sequence (no modulation or gate).
+        norm_hidden = self.cross_attn_norm(hidden_states)
+        attn_out = self.cross_attn(
+            norm_hidden,
+            encoder_hidden_states=encoder_hidden_states,
+            attention_mask=encoder_attention_mask
+        )
+        hidden_states = hidden_states + attn_out
+
+        # Gated MLP with its own AdaLN modulation.
+        norm_hidden = self.mlp_norm(hidden_states)
+        norm_hidden = norm_hidden * (1 + c_scale_msa) + c_shift_msa
+
+        mlp_out = self.mlp(norm_hidden)
+        hidden_states = hidden_states + mlp_out * c_gate_msa
+
+        return hidden_states
+
+class AceStepEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ intermediate_size,
+ rms_norm_eps=1e-6,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ self.self_attn = AceStepAttention(
+ hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps,
+ is_cross_attention=False, dtype=dtype, device=device, operations=operations
+ )
+ self.input_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.post_attention_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations)
+
+ def forward(self, hidden_states, position_embeddings, attention_mask=None):
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.self_attn(
+ hidden_states=hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+class AceStepLyricEncoder(nn.Module):
+ def __init__(
+ self,
+ text_hidden_dim,
+ hidden_size,
+ num_layers,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ intermediate_size,
+ rms_norm_eps=1e-6,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.embed_tokens = Linear(text_hidden_dim, hidden_size, dtype=dtype, device=device)
+ self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+
+ self.rotary_emb = RotaryEmbedding(
+ head_dim,
+ base=1000000.0,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+
+ self.layers = nn.ModuleList([
+ AceStepEncoderLayer(
+ hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(num_layers)
+ ])
+
+ def forward(self, inputs_embeds, attention_mask=None):
+ hidden_states = self.embed_tokens(inputs_embeds)
+ seq_len = hidden_states.shape[1]
+ cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len)
+ position_embeddings = (cos, sin)
+
+ for layer in self.layers:
+ hidden_states = layer(
+ hidden_states,
+ position_embeddings=position_embeddings,
+ attention_mask=attention_mask
+ )
+
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
+
+class AceStepTimbreEncoder(nn.Module):
+    """Encodes packed reference-audio acoustic features into one timbre
+    embedding per reference clip, then scatters them into per-batch rows."""
+    def __init__(
+        self,
+        timbre_hidden_dim,
+        hidden_size,
+        num_layers,
+        num_heads,
+        num_kv_heads,
+        head_dim,
+        intermediate_size,
+        rms_norm_eps=1e-6,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+        Linear = get_layer_class(operations, "Linear")
+        self.embed_tokens = Linear(timbre_hidden_dim, hidden_size, dtype=dtype, device=device)
+        self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+
+        self.rotary_emb = RotaryEmbedding(
+            head_dim,
+            base=1000000.0,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+
+        self.layers = nn.ModuleList([
+            AceStepEncoderLayer(
+                hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
+                dtype=dtype, device=device, operations=operations
+            )
+            for _ in range(num_layers)
+        ])
+        # NOTE(review): special_token is declared (so it is expected in
+        # checkpoints) but never referenced in this forward path — confirm.
+        self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
+
+    def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask):
+        # Scatter N packed embeddings (one per reference clip) into a dense
+        # (B, max_count, d) tensor; refer_audio_order_mask[i] gives the batch
+        # row packed entry i belongs to. Uses a one-hot matmul so the scatter
+        # stays vectorized and differentiable.
+        N, d = timbre_embs_packed.shape
+        device = timbre_embs_packed.device
+        B = N  # NOTE(review): output batch dim is taken as N — verify for multi-clip rows.
+        counts = torch.bincount(refer_audio_order_mask, minlength=B)
+        max_count = counts.max().item()
+
+        # Stable key sort groups entries by batch row while preserving their
+        # original relative order within a row.
+        sorted_indices = torch.argsort(
+            refer_audio_order_mask * N + torch.arange(N, device=device),
+            stable=True
+        )
+        sorted_batch_ids = refer_audio_order_mask[sorted_indices]
+
+        # Slot of each entry within its batch row.
+        positions = torch.arange(N, device=device)
+        batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]])
+        positions_in_sorted = positions - batch_starts[sorted_batch_ids]
+
+        inverse_indices = torch.empty_like(sorted_indices)
+        inverse_indices[sorted_indices] = torch.arange(N, device=device)
+        positions_in_batch = positions_in_sorted[inverse_indices]
+
+        # Flat (row, slot) index per packed entry; scatter via one-hot matmul.
+        indices_2d = refer_audio_order_mask * max_count + positions_in_batch
+        one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(timbre_embs_packed.dtype)
+
+        timbre_embs_flat = one_hot.t() @ timbre_embs_packed
+        timbre_embs_unpack = timbre_embs_flat.view(B, max_count, d)
+
+        # Mask marks slots that actually received an embedding.
+        mask_flat = (one_hot.sum(dim=0) > 0).long()
+        new_mask = mask_flat.view(B, max_count)
+        return timbre_embs_unpack, new_mask
+
+    def forward(self, refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, attention_mask=None):
+        hidden_states = self.embed_tokens(refer_audio_acoustic_hidden_states_packed)
+        if hidden_states.dim() == 2:
+            # Allow unbatched input: promote to a batch of one.
+            hidden_states = hidden_states.unsqueeze(0)
+
+        seq_len = hidden_states.shape[1]
+        cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len)
+
+        for layer in self.layers:
+            hidden_states = layer(
+                hidden_states,
+                position_embeddings=(cos, sin),
+                attention_mask=attention_mask
+            )
+        hidden_states = self.norm(hidden_states)
+
+        # First-token pooling: one embedding per packed sequence.
+        flat_states = hidden_states[:, 0, :]
+        unpacked_embs, unpacked_mask = self.unpack_timbre_embeddings(flat_states, refer_audio_order_mask)
+        return unpacked_embs, unpacked_mask
+
+
+def pack_sequences(hidden1, hidden2, mask1, mask2):
+ hidden_cat = torch.cat([hidden1, hidden2], dim=1)
+ B, L, D = hidden_cat.shape
+
+ if mask1 is not None and mask2 is not None:
+ mask_cat = torch.cat([mask1, mask2], dim=1)
+ sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True)
+ gather_idx = sort_idx.unsqueeze(-1).expand(B, L, D)
+ hidden_sorted = torch.gather(hidden_cat, 1, gather_idx)
+ lengths = mask_cat.sum(dim=1)
+ new_mask = (torch.arange(L, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1))
+ else:
+ new_mask = None
+ hidden_sorted = hidden_cat
+
+ return hidden_sorted, new_mask
+
+class AceStepConditionEncoder(nn.Module):
+ def __init__(
+ self,
+ text_hidden_dim,
+ timbre_hidden_dim,
+ hidden_size,
+ num_lyric_layers,
+ num_timbre_layers,
+ num_heads,
+ num_kv_heads,
+ head_dim,
+ intermediate_size,
+ rms_norm_eps=1e-6,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.text_projector = Linear(text_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device)
+
+ self.lyric_encoder = AceStepLyricEncoder(
+ text_hidden_dim=text_hidden_dim,
+ hidden_size=hidden_size,
+ num_layers=num_lyric_layers,
+ num_heads=num_heads,
+ num_kv_heads=num_kv_heads,
+ head_dim=head_dim,
+ intermediate_size=intermediate_size,
+ rms_norm_eps=rms_norm_eps,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+
+ self.timbre_encoder = AceStepTimbreEncoder(
+ timbre_hidden_dim=timbre_hidden_dim,
+ hidden_size=hidden_size,
+ num_layers=num_timbre_layers,
+ num_heads=num_heads,
+ num_kv_heads=num_kv_heads,
+ head_dim=head_dim,
+ intermediate_size=intermediate_size,
+ rms_norm_eps=rms_norm_eps,
+ dtype=dtype,
+ device=device,
+ operations=operations
+ )
+
+ def forward(
+ self,
+ text_hidden_states=None,
+ text_attention_mask=None,
+ lyric_hidden_states=None,
+ lyric_attention_mask=None,
+ refer_audio_acoustic_hidden_states_packed=None,
+ refer_audio_order_mask=None
+ ):
+ text_emb = self.text_projector(text_hidden_states)
+
+ lyric_emb = self.lyric_encoder(
+ inputs_embeds=lyric_hidden_states,
+ attention_mask=lyric_attention_mask
+ )
+
+ timbre_emb, timbre_mask = self.timbre_encoder(
+ refer_audio_acoustic_hidden_states_packed,
+ refer_audio_order_mask
+ )
+
+ merged_emb, merged_mask = pack_sequences(lyric_emb, timbre_emb, lyric_attention_mask, timbre_mask)
+ final_emb, final_mask = pack_sequences(merged_emb, text_emb, merged_mask, text_attention_mask)
+
+ return final_emb, final_mask
+
+# --------------------------------------------------------------------------------
+# Main Diffusion Model (DiT)
+# --------------------------------------------------------------------------------
+
+class AceStepDiTModel(nn.Module):
+    """Patchified 1D DiT over audio latents, cross-attending to conditions.
+
+    Latents are concatenated with context latents on the feature axis,
+    patchified by a strided Conv1d, processed by DiT layers, then un-patchified
+    by a ConvTranspose1d back to the acoustic feature dimension.
+    """
+    def __init__(
+        self,
+        in_channels,
+        hidden_size,
+        num_layers,
+        num_heads,
+        num_kv_heads,
+        head_dim,
+        intermediate_size,
+        patch_size,
+        audio_acoustic_hidden_dim,
+        layer_types=None,
+        sliding_window=128,
+        rms_norm_eps=1e-6,
+        dtype=None,
+        device=None,
+        operations=None
+    ):
+        super().__init__()
+        self.patch_size = patch_size
+        self.rotary_emb = RotaryEmbedding(
+            head_dim,
+            base=1000000.0,
+            dtype=dtype,
+            device=device,
+            operations=operations
+        )
+
+        Conv1d = get_layer_class(operations, "Conv1d")
+        ConvTranspose1d = get_layer_class(operations, "ConvTranspose1d")
+        Linear = get_layer_class(operations, "Linear")
+
+        # The leading nn.Identity keeps the conv weights at submodule index 1,
+        # presumably to match the original checkpoint key layout — confirm.
+        self.proj_in = nn.Sequential(
+            nn.Identity(),
+            Conv1d(
+                in_channels, hidden_size, kernel_size=patch_size, stride=patch_size,
+                dtype=dtype, device=device))
+
+        self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
+        self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
+        self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
+
+        if layer_types is None:
+            layer_types = ["full_attention"] * num_layers
+
+        # Cycle a shorter pattern (e.g. alternating sliding/full) over all layers.
+        if len(layer_types) < num_layers:
+            layer_types = list(itertools.islice(itertools.cycle(layer_types), num_layers))
+
+        self.layers = nn.ModuleList([
+            AceStepDiTLayer(
+                hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
+                layer_type=layer_types[i],
+                sliding_window=sliding_window,
+                dtype=dtype, device=device, operations=operations
+            ) for i in range(num_layers)
+        ])
+
+        self.norm_out = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+        self.proj_out = nn.Sequential(
+            nn.Identity(),
+            ConvTranspose1d(hidden_size, audio_acoustic_hidden_dim, kernel_size=patch_size, stride=patch_size, dtype=dtype, device=device)
+        )
+
+        # Learned base for the final AdaLN shift/scale (2-way).
+        self.scale_shift_table = nn.Parameter(torch.empty(1, 2, hidden_size, dtype=dtype, device=device))
+
+    def forward(
+        self,
+        hidden_states,
+        timestep,
+        timestep_r,
+        attention_mask,
+        encoder_hidden_states,
+        encoder_attention_mask,
+        context_latents
+    ):
+        # Two timestep embeddings — t itself and the gap (t - t_r) — summed
+        # for both the final-AdaLN temb and the per-layer 6-way projection.
+        temb_t, proj_t = self.time_embed(timestep, dtype=hidden_states.dtype)
+        temb_r, proj_r = self.time_embed_r(timestep - timestep_r, dtype=hidden_states.dtype)
+        temb = temb_t + temb_r
+        timestep_proj = proj_t + proj_r
+
+        # Concatenate conditioning latents on the feature axis (time unchanged).
+        x = torch.cat([context_latents, hidden_states], dim=-1)
+        original_seq_len = x.shape[1]
+
+        # Right-pad time so it divides evenly by patch_size.
+        pad_length = 0
+        if x.shape[1] % self.patch_size != 0:
+            pad_length = self.patch_size - (x.shape[1] % self.patch_size)
+            x = F.pad(x, (0, 0, 0, pad_length), mode='constant', value=0)
+
+        # (B, T, C) -> (B, C, T) for Conv1d patchify, then back.
+        x = x.transpose(1, 2)
+        x = self.proj_in(x)
+        x = x.transpose(1, 2)
+
+        encoder_hidden_states = self.condition_embedder(encoder_hidden_states)
+
+        seq_len = x.shape[1]
+        cos, sin = self.rotary_emb(x, seq_len=seq_len)
+
+        # NOTE(review): both masks are passed as None here, so the incoming
+        # attention_mask/encoder_attention_mask args are unused — confirm.
+        for layer in self.layers:
+            x = layer(
+                hidden_states=x,
+                temb=timestep_proj,
+                encoder_hidden_states=encoder_hidden_states,
+                position_embeddings=(cos, sin),
+                attention_mask=None,
+                encoder_attention_mask=None
+            )
+
+        # Final AdaLN: learned table + timestep embedding -> shift/scale.
+        shift, scale = (comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
+        x = self.norm_out(x) * (1 + scale) + shift
+
+        # Un-patchify back to (B, T, acoustic_dim) and crop the time padding.
+        x = x.transpose(1, 2)
+        x = self.proj_out(x)
+        x = x.transpose(1, 2)
+
+        x = x[:, :original_seq_len, :]
+        return x
+
+
+class AttentionPooler(nn.Module):
+ def __init__(self, hidden_size, num_layers, head_dim, rms_norm_eps, dtype=None, device=None, operations=None):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
+ self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device)
+ self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations)
+ self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))
+
+ self.layers = nn.ModuleList([
+ AceStepEncoderLayer(
+ hidden_size, 16, 8, head_dim, hidden_size * 3, rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(num_layers)
+ ])
+
+ def forward(self, x):
+ B, T, P, D = x.shape
+ x = self.embed_tokens(x)
+ special = self.special_token.expand(B, T, 1, -1)
+ x = torch.cat([special, x], dim=2)
+ x = x.view(B * T, P + 1, D)
+
+ cos, sin = self.rotary_emb(x, seq_len=P + 1)
+ for layer in self.layers:
+ x = layer(x, (cos, sin))
+
+ x = self.norm(x)
+ return x[:, 0, :].view(B, T, D)
+
+
+class FSQ(nn.Module):
+    """Finite Scalar Quantization: each of len(levels) channels is quantized
+    to a fixed grid of levels[i] values spanning [-1, 1]."""
+    def __init__(
+        self,
+        levels,
+        dim=None,
+        device=None,
+        dtype=None,
+        operations=None
+    ):
+        super().__init__()
+
+        _levels = torch.tensor(levels, dtype=torch.int32, device=device)
+        self.register_buffer('_levels', _levels, persistent=False)
+
+        # Mixed-radix basis used to flatten per-channel codes into one index.
+        _basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.int32, device=device), dim=0)
+        self.register_buffer('_basis', _basis, persistent=False)
+
+        self.codebook_dim = len(levels)
+        self.dim = dim if dim is not None else self.codebook_dim
+
+        # Optional projections when the model dim differs from the code dim.
+        requires_projection = self.dim != self.codebook_dim
+        if requires_projection:
+            self.project_in = operations.Linear(self.dim, self.codebook_dim, device=device, dtype=dtype)
+            self.project_out = operations.Linear(self.codebook_dim, self.dim, device=device, dtype=dtype)
+        else:
+            self.project_in = nn.Identity()
+            self.project_out = nn.Identity()
+
+        self.codebook_size = self._levels.prod().item()
+
+        # Materialize the implicit codebook (index -> code vector) once.
+        indices = torch.arange(self.codebook_size, device=device)
+        implicit_codebook = self._indices_to_codes(indices)
+
+        if dtype is not None:
+            implicit_codebook = implicit_codebook.to(dtype)
+
+        self.register_buffer('implicit_codebook', implicit_codebook, persistent=False)
+
+    def bound(self, z):
+        # Squash with tanh into the level range, round via floor(x + 0.5), and
+        # use a straight-through estimator so gradients pass through `bracket`.
+        levels_minus_1 = (self._levels - 1).to(z.dtype)
+        scale = 2. / levels_minus_1
+        bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5
+
+        zhat = bracket.floor()
+        bracket_ste = bracket + (zhat - bracket).detach()
+
+        # Map from [0, levels-1] back to [-1, 1].
+        return scale * bracket_ste - 1.
+
+    def _indices_to_codes(self, indices):
+        # Mixed-radix decode: flat index -> per-channel code values in [-1, 1].
+        indices = indices.unsqueeze(-1)
+        codes_non_centered = (indices // self._basis) % self._levels
+        return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1.
+
+    def codes_to_indices(self, zhat):
+        # Mixed-radix encode: per-channel codes in [-1, 1] -> flat integer index.
+        zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1))
+        return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32)
+
+    def forward(self, z):
+        """Quantize z; returns (dequantized values, integer code indices)."""
+        orig_dtype = z.dtype
+        z = self.project_in(z)
+
+        codes = self.bound(z)
+        indices = self.codes_to_indices(codes)
+
+        out = self.project_out(codes)
+        return out.to(orig_dtype), indices
+
+
+class ResidualFSQ(nn.Module):
+    """Residual quantization built from stacked FSQ stages.
+
+    Stage `ind` quantizes the running residual at a per-channel scale of
+    levels**(-ind), so later stages refine earlier ones.
+    """
+    def __init__(
+        self,
+        levels,
+        num_quantizers,
+        dim=None,
+        bound_hard_clamp=True,
+        device=None,
+        dtype=None,
+        operations=None,
+        **kwargs
+    ):
+        super().__init__()
+
+        codebook_dim = len(levels)
+        dim = dim if dim is not None else codebook_dim
+
+        # Optional projections between model dim and the FSQ code dim.
+        requires_projection = codebook_dim != dim
+        if requires_projection:
+            self.project_in = operations.Linear(dim, codebook_dim, device=device, dtype=dtype)
+            self.project_out = operations.Linear(codebook_dim, dim, device=device, dtype=dtype)
+        else:
+            self.project_in = nn.Identity()
+            self.project_out = nn.Identity()
+
+        self.layers = nn.ModuleList()
+        levels_tensor = torch.tensor(levels, device=device)
+        scales = []
+
+        for ind in range(num_quantizers):
+            # Per-stage, per-channel scale: levels^(-ind).
+            scale_val = levels_tensor.float() ** -ind
+            scales.append(scale_val)
+
+            self.layers.append(FSQ(
+                levels=levels,
+                dim=codebook_dim,
+                device=device,
+                dtype=dtype,
+                operations=operations
+            ))
+
+        scales_tensor = torch.stack(scales)
+        if dtype is not None:
+            scales_tensor = scales_tensor.to(dtype)
+        self.register_buffer('scales', scales_tensor, persistent=False)
+
+        if bound_hard_clamp:
+            # Soft-clamp bound 1 + 1/(levels-1): slightly beyond the code range.
+            val = 1 + (1 / (levels_tensor.float() - 1))
+            if dtype is not None:
+                val = val.to(dtype)
+            self.register_buffer('soft_clamp_input_value', val, persistent=False)
+
+    def get_output_from_indices(self, indices, dtype=torch.float32):
+        """Decode per-stage indices back to (projected) continuous embeddings."""
+        if indices.dim() == 2:
+            # Input without a stage axis (single quantizer): add one.
+            indices = indices.unsqueeze(-1)
+
+        all_codes = []
+        for i, layer in enumerate(self.layers):
+            idx = indices[..., i].long()
+            # Look up each stage's implicit codebook and restore its scale.
+            codes = F.embedding(idx, comfy.model_management.cast_to(layer.implicit_codebook, device=idx.device, dtype=dtype))
+            all_codes.append(codes * comfy.model_management.cast_to(self.scales[i], device=idx.device, dtype=dtype))
+
+        codes_summed = torch.stack(all_codes).sum(dim=0)
+        return self.project_out(codes_summed)
+
+    def forward(self, x):
+        """Quantize x; returns (quantized output, per-stage indices on dim -1)."""
+        x = self.project_in(x)
+
+        if hasattr(self, 'soft_clamp_input_value'):
+            # Smoothly clamp the input into the representable range.
+            sc_val = self.soft_clamp_input_value.to(x.dtype)
+            x = (x / sc_val).tanh() * sc_val
+
+        quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
+        residual = x
+        all_indices = []
+
+        for layer, scale in zip(self.layers, self.scales):
+            scale = scale.to(residual.dtype)
+
+            # Quantize the scaled residual; the detach below keeps the
+            # straight-through gradient path through `quantized` only.
+            quantized, indices = layer(residual / scale)
+            quantized = quantized * scale
+
+            residual = residual - quantized.detach()
+            quantized_out = quantized_out + quantized
+            all_indices.append(indices)
+
+        quantized_out = self.project_out(quantized_out)
+        all_indices = torch.stack(all_indices, dim=-1)
+
+        return quantized_out, all_indices
+
+
+class AceStepAudioTokenizer(nn.Module):
+ def __init__(
+ self,
+ audio_acoustic_hidden_dim,
+ hidden_size,
+ pool_window_size,
+ fsq_dim,
+ fsq_levels,
+ fsq_input_num_quantizers,
+ num_layers,
+ head_dim,
+ rms_norm_eps,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.audio_acoustic_proj = Linear(audio_acoustic_hidden_dim, hidden_size, dtype=dtype, device=device)
+ self.attention_pooler = AttentionPooler(
+ hidden_size, num_layers, head_dim, rms_norm_eps, dtype=dtype, device=device, operations=operations
+ )
+ self.pool_window_size = pool_window_size
+ self.fsq_dim = fsq_dim
+ self.quantizer = ResidualFSQ(
+ dim=fsq_dim,
+ levels=fsq_levels,
+ num_quantizers=fsq_input_num_quantizers,
+ bound_hard_clamp=True,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ def forward(self, hidden_states):
+ hidden_states = self.audio_acoustic_proj(hidden_states)
+ hidden_states = self.attention_pooler(hidden_states)
+ quantized, indices = self.quantizer(hidden_states)
+ return quantized, indices
+
+ def tokenize(self, x):
+ B, T, D = x.shape
+ P = self.pool_window_size
+
+ if T % P != 0:
+ pad = P - (T % P)
+ x = F.pad(x, (0, 0, 0, pad))
+ T = x.shape[1]
+
+ T_patch = T // P
+ x = x.view(B, T_patch, P, D)
+
+ quantized, indices = self.forward(x)
+ return quantized, indices
+
+
+class AudioTokenDetokenizer(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ pool_window_size,
+ audio_acoustic_hidden_dim,
+ num_layers,
+ head_dim,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ Linear = get_layer_class(operations, "Linear")
+ self.pool_window_size = pool_window_size
+ self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
+ self.special_tokens = nn.Parameter(torch.empty(1, pool_window_size, hidden_size, dtype=dtype, device=device))
+ self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations)
+ self.layers = nn.ModuleList([
+ AceStepEncoderLayer(
+ hidden_size, 16, 8, head_dim, hidden_size * 3, 1e-6,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(num_layers)
+ ])
+ self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
+ self.proj_out = Linear(hidden_size, audio_acoustic_hidden_dim, dtype=dtype, device=device)
+
+ def forward(self, x):
+ B, T, D = x.shape
+ x = self.embed_tokens(x)
+ x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1)
+ x = x + comfy.model_management.cast_to(self.special_tokens.expand(B, T, -1, -1), device=x.device, dtype=x.dtype)
+ x = x.view(B * T, self.pool_window_size, D)
+
+ cos, sin = self.rotary_emb(x, seq_len=self.pool_window_size)
+ for layer in self.layers:
+ x = layer(x, (cos, sin))
+
+ x = self.norm(x)
+ x = self.proj_out(x)
+ return x.view(B, T * self.pool_window_size, -1)
+
+
+class AceStepConditionGenerationModel(nn.Module):
+ def __init__(
+ self,
+ in_channels=192,
+ hidden_size=2048,
+ text_hidden_dim=1024,
+ timbre_hidden_dim=64,
+ audio_acoustic_hidden_dim=64,
+ num_dit_layers=24,
+ num_lyric_layers=8,
+ num_timbre_layers=4,
+ num_tokenizer_layers=2,
+ num_heads=16,
+ num_kv_heads=8,
+ head_dim=128,
+ intermediate_size=6144,
+ patch_size=2,
+ pool_window_size=5,
+ rms_norm_eps=1e-06,
+ timestep_mu=-0.4,
+ timestep_sigma=1.0,
+ data_proportion=0.5,
+ sliding_window=128,
+ layer_types=None,
+ fsq_dim=2048,
+ fsq_levels=[8, 8, 8, 5, 5, 5],
+ fsq_input_num_quantizers=1,
+ audio_model=None,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.timestep_mu = timestep_mu
+ self.timestep_sigma = timestep_sigma
+ self.data_proportion = data_proportion
+ self.pool_window_size = pool_window_size
+
+ if layer_types is None:
+ layer_types = []
+ for i in range(num_dit_layers):
+ layer_types.append("sliding_attention" if i % 2 == 0 else "full_attention")
+
+ self.decoder = AceStepDiTModel(
+ in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim,
+ intermediate_size, patch_size, audio_acoustic_hidden_dim,
+ layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.encoder = AceStepConditionEncoder(
+ text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers,
+ num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.tokenizer = AceStepAudioTokenizer(
+ audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.detokenizer = AudioTokenDetokenizer(
+ hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))
+
+ def prepare_condition(
+ self,
+ text_hidden_states, text_attention_mask,
+ lyric_hidden_states, lyric_attention_mask,
+ refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask,
+ src_latents, chunk_masks, is_covers,
+ precomputed_lm_hints_25Hz=None,
+ audio_codes=None
+ ):
+ encoder_hidden, encoder_mask = self.encoder(
+ text_hidden_states, text_attention_mask,
+ lyric_hidden_states, lyric_attention_mask,
+ refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask
+ )
+
+ if precomputed_lm_hints_25Hz is not None:
+ lm_hints = precomputed_lm_hints_25Hz
+ else:
+ if audio_codes is not None:
+ if audio_codes.shape[1] * 5 < src_latents.shape[1]:
+ audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847)
+ lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype)
+ else:
+ assert False
+ # TODO ?
+
+ lm_hints = self.detokenizer(lm_hints_5Hz)
+
+ lm_hints = lm_hints[:, :src_latents.shape[1], :]
+ if is_covers is None:
+ src_latents = lm_hints
+ else:
+ src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents)
+
+ context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1)
+
+ return encoder_hidden, encoder_mask, context_latents
+
+ def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs):
+ text_attention_mask = None
+ lyric_attention_mask = None
+ refer_audio_order_mask = None
+ attention_mask = None
+ chunk_masks = None
+ is_covers = None
+ src_latents = None
+ precomputed_lm_hints_25Hz = None
+ lyric_hidden_states = lyric_embed
+ text_hidden_states = context
+ refer_audio_acoustic_hidden_states_packed = refer_audio.movedim(-1, -2)
+
+ x = x.movedim(-1, -2)
+
+ if refer_audio_order_mask is None:
+ refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long)
+
+ if src_latents is None and is_covers is None:
+ src_latents = x
+
+ if chunk_masks is None:
+ chunk_masks = torch.ones_like(x)
+
+ enc_hidden, enc_mask, context_latents = self.prepare_condition(
+ text_hidden_states, text_attention_mask,
+ lyric_hidden_states, lyric_attention_mask,
+ refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask,
+ src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
+ )
+
+ out = self.decoder(hidden_states=x,
+ timestep=timestep,
+ timestep_r=timestep,
+ attention_mask=attention_mask,
+ encoder_hidden_states=enc_hidden,
+ encoder_attention_mask=enc_mask,
+ context_latents=context_latents
+ )
+
+ return out.movedim(-1, -2)
diff --git a/comfy/lora.py b/comfy/lora.py
index 7b31d055c..44030bcab 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -332,6 +332,12 @@ def model_lora_keys_unet(model, key_map={}):
key_map["{}".format(key_lora)] = k
key_map["transformer.{}".format(key_lora)] = k
+ if isinstance(model, comfy.model_base.ACEStep15):
+ for k in sdk:
+ if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"):
+ key_lora = k[len("diffusion_model.decoder."):-len(".weight")]
+ key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
+
return key_map
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 85acdb66a..89944548c 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
+import comfy.ldm.ace.ace_step15
import comfy.model_management
import comfy.patcher_extension
@@ -1540,6 +1541,47 @@ class ACEStep(BaseModel):
out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
return out
+class ACEStep15(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ device = kwargs["device"]
+
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
+ if conditioning_lyrics is not None:
+ out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics)
+
+ refer_audio = kwargs.get("reference_audio_timbre_latents", None)
+ if refer_audio is None or len(refer_audio) == 0:
+ refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02,
+ 2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01,
+ -2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00,
+ 7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00,
+ 2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01,
+ -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00,
+ 5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01,
+ -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01,
+ 2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01,
+ 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01,
+ -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01,
+ -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01,
+ 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750)
+ else:
+ refer_audio = refer_audio[-1]
+ out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
+
+ audio_codes = kwargs.get("audio_codes", None)
+ if audio_codes is not None:
+ out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device))
+
+ return out
+
class Omnigen2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 8cea16e50..e8ad725df 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -655,6 +655,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
+ if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys:
+ dit_config = {}
+ dit_config["audio_model"] = "ace1.5"
+ return dit_config
+
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
diff --git a/comfy/model_management.py b/comfy/model_management.py
index cd035f017..72348258b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -767,6 +767,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
+def load_model_gpu(model):
+ return load_models_gpu([model])
+
def loaded_models(only_currently_used=False):
output = []
for m in current_loaded_models:
diff --git a/comfy/sd.py b/comfy/sd.py
index fd0ac85e7..bc63d6ced 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -59,6 +59,7 @@ import comfy.text_encoders.kandinsky5
import comfy.text_encoders.jina_clip_2
import comfy.text_encoders.newbie
import comfy.text_encoders.anima
+import comfy.text_encoders.ace15
import comfy.model_patcher
import comfy.lora
@@ -452,6 +453,8 @@ class VAE:
self.extra_1d_channel = None
self.crop_input = True
+ self.audio_sample_rate = 44100
+
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -549,14 +552,27 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig})
elif "decoder.layers.1.layers.0.beta" in sd:
- self.first_stage_model = AudioOobleckVAE()
+ config = {}
+ param_key = None
+ self.upscale_ratio = 2048
+ self.downscale_ratio = 2048
+ if "decoder.layers.2.layers.1.weight_v" in sd:
+ param_key = "decoder.layers.2.layers.1.weight_v"
+ if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd:
+ param_key = "decoder.layers.2.layers.1.parametrizations.weight.original1"
+ if param_key is not None:
+ if sd[param_key].shape[-1] == 12:
+ config["strides"] = [2, 4, 4, 6, 10]
+ self.audio_sample_rate = 48000
+ self.upscale_ratio = 1920
+ self.downscale_ratio = 1920
+
+ self.first_stage_model = AudioOobleckVAE(**config)
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
self.latent_channels = 64
self.output_channels = 2
self.pad_channel_value = "replicate"
- self.upscale_ratio = 2048
- self.downscale_ratio = 2048
self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
@@ -856,7 +872,7 @@ class VAE:
/ 3.0)
return output
- def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
+ def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
else:
@@ -1427,6 +1443,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_data_jina = clip_data[0]
tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
+ elif clip_type == CLIPType.ACE:
+ te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])]
+ if TEModel.QWEN3_4B in te_models:
+ model_type = "qwen3_4b"
+ else:
+ model_type = "qwen3_2b"
+ clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 9ecfc9c55..4c817d468 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -155,6 +155,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all":
pass
+ elif isinstance(layer_idx, list):
+ self.layer = layer_idx
elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index d25271d6e..77264ed28 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -24,6 +24,7 @@ import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.anima
+import comfy.text_encoders.ace15
from . import supported_models_base
from . import latent_formats
@@ -1596,6 +1597,46 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
+class ACEStep15(supported_models_base.BASE):
+ unet_config = {
+ "audio_model": "ace1.5",
+ }
+
+ unet_extra_config = {
+ }
+
+ sampling_settings = {
+ "multiplier": 1.0,
+ "shift": 3.0,
+ }
+
+ latent_format = comfy.latent_formats.ACEAudio15
+
+ memory_usage_factor = 4.7
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.ACEStep15(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref))
+ detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
+ if "dtype_llama" in detect_2b:
+ detect = detect_2b
+ detect["lm_model"] = "qwen3_2b"
+ elif "dtype_llama" in detect_4b:
+ detect = detect_4b
+ detect["lm_model"] = "qwen3_4b"
+ else: detect = {"lm_model": "qwen3_2b"}  # fallback so `detect` is always bound when no LM TE weights are present
+ return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
+
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
new file mode 100644
index 000000000..73d710671
--- /dev/null
+++ b/comfy/text_encoders/ace15.py
@@ -0,0 +1,247 @@
+from .anima import Qwen3Tokenizer
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import torch
+import math
+import comfy.utils, comfy.model_management  # model_management used by should_use_bf16 below; import explicitly
+
+
+def sample_manual_loop_no_classes(
+ model,
+ ids=None,
+ paddings=[],
+ execution_dtype=None,
+ cfg_scale: float = 2.0,
+ temperature: float = 0.85,
+ top_p: float = 0.9,
+ top_k: int = None,
+ seed: int = 1,
+ min_tokens: int = 1,
+ max_new_tokens: int = 2048,
+ audio_start_id: int = 151669, # The cutoff ID for audio codes
+ eos_token_id: int = 151645,
+):
+ device = model.execution_device
+
+ if execution_dtype is None:
+ if comfy.model_management.should_use_bf16(device):
+ execution_dtype = torch.bfloat16
+ else:
+ execution_dtype = torch.float32
+
+ embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
+ for i, t in enumerate(paddings):
+ attention_mask[i, :t] = 0
+ attention_mask[i, t:] = 1
+
+ output_audio_codes = []
+ past_key_values = []
+ generator = torch.Generator(device=device)
+ generator.manual_seed(seed)
+ model_config = model.transformer.model.config
+
+ for x in range(model_config.num_hidden_layers):
+ past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+
+ progress_bar = comfy.utils.ProgressBar(max_new_tokens)
+
+ for step in range(max_new_tokens):
+ outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
+ next_token_logits = model.transformer.logits(outputs[0])[:, -1]
+ past_key_values = outputs[2]
+
+ cond_logits = next_token_logits[0:1]
+ uncond_logits = next_token_logits[1:2]
+ cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+
+ if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ eos_score = cfg_logits[:, eos_token_id].clone()
+
+ remove_logit_value = torch.finfo(cfg_logits.dtype).min
+ # Only generate audio tokens
+ cfg_logits[:, :audio_start_id] = remove_logit_value
+
+ if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ cfg_logits[:, eos_token_id] = eos_score
+
+ if top_k is not None and top_k > 0:
+ top_k_vals, _ = torch.topk(cfg_logits, top_k)
+ min_val = top_k_vals[..., -1, None]
+ cfg_logits[cfg_logits < min_val] = remove_logit_value
+
+ if top_p is not None and top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
+ cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+ cfg_logits[indices_to_remove] = remove_logit_value
+
+ if temperature > 0:
+ cfg_logits = cfg_logits / temperature
+ next_token = torch.multinomial(torch.softmax(cfg_logits, dim=-1), num_samples=1, generator=generator).squeeze(1)
+ else:
+ next_token = torch.argmax(cfg_logits, dim=-1)
+
+ token = next_token.item()
+
+ if token == eos_token_id:
+ break
+
+ embed, _, _, _ = model.process_tokens([[token]], device)
+ embeds = embed.repeat(2, 1, 1)
+ attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
+
+ output_audio_codes.append(token - audio_start_id)
+ progress_bar.update_absolute(step)
+
+ return output_audio_codes
+
+
+def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0):
+ cfg_scale = 2.0
+
+ positive = [[token for token, _ in inner_list] for inner_list in positive]
+ negative = [[token for token, _ in inner_list] for inner_list in negative]
+ positive = positive[0]
+ negative = negative[0]
+
+ neg_pad = 0
+ if len(negative) < len(positive):
+ neg_pad = (len(positive) - len(negative))
+ negative = [model.special_tokens["pad"]] * neg_pad + negative
+
+ pos_pad = 0
+ if len(negative) > len(positive):
+ pos_pad = (len(negative) - len(positive))
+ positive = [model.special_tokens["pad"]] * pos_pad + positive
+
+ paddings = [pos_pad, neg_pad]
+ return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+
+
+class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer)
+
+ def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
+ out = {}
+ lyrics = kwargs.get("lyrics", "")
+ bpm = kwargs.get("bpm", 120)
+ duration = kwargs.get("duration", 120)
+ keyscale = kwargs.get("keyscale", "C major")
+ timesignature = kwargs.get("timesignature", 2)
+ language = kwargs.get("language", "en")
+ seed = kwargs.get("seed", 0)
+
+ duration = math.ceil(duration)
+ meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature)
+ lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n\n{}\n\n\n<|im_end|>\n"
+
+ meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration)
+ out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True)
+ out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True)
+
+ out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs)
+ out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
+ out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed}
+ return out
+
+
+class Qwen3_06BModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B_ACE15, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+class Qwen3_2B_ACE15(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["quantization_metadata"] = llama_quantization_metadata
+
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+class Qwen3_4B_ACE15(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["quantization_metadata"] = llama_quantization_metadata
+
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+class ACE15TEModel(torch.nn.Module):
+ def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}):
+ super().__init__()
+ if dtype_llama is None:
+ dtype_llama = dtype
+
+ model = None
+ self.constant = 0.4375
+ if lm_model == "qwen3_4b":
+ model = Qwen3_4B_ACE15
+ self.constant = 0.5625
+ elif lm_model == "qwen3_2b":
+ model = Qwen3_2B_ACE15
+
+ self.lm_model = lm_model
+ self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options)
+ if model is not None:
+ setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
+
+ self.dtypes = set([dtype, dtype_llama])
+
+ def encode_token_weights(self, token_weight_pairs):
+ token_weight_pairs_base = token_weight_pairs["qwen3_06b"]
+ token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
+
+ self.qwen3_06b.set_clip_options({"layer": None})
+ base_out, _, extra = self.qwen3_06b.encode_token_weights(token_weight_pairs_base)
+ self.qwen3_06b.set_clip_options({"layer": [0]})
+ lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
+
+ lm_metadata = token_weight_pairs["lm_metadata"]
+ audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"])
+
+ return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]}
+
+ def set_clip_options(self, options):
+ self.qwen3_06b.set_clip_options(options)
+ lm_model = getattr(self, self.lm_model, None)
+ if lm_model is not None:
+ lm_model.set_clip_options(options)
+
+ def reset_clip_options(self):
+ self.qwen3_06b.reset_clip_options()
+ lm_model = getattr(self, self.lm_model, None)
+ if lm_model is not None:
+ lm_model.reset_clip_options()
+
+ def load_sd(self, sd):
+ if "model.layers.0.post_attention_layernorm.weight" in sd:
+ shape = sd["model.layers.0.post_attention_layernorm.weight"].shape
+ if shape[0] == 1024:
+ return self.qwen3_06b.load_sd(sd)
+ else:
+ return getattr(self, self.lm_model).load_sd(sd)
+
+ def memory_estimation_function(self, token_weight_pairs, device=None):
+ lm_metadata = token_weight_pairs["lm_metadata"]
+ constant = self.constant
+ if comfy.model_management.should_use_bf16(device):
+ constant *= 0.5
+
+ token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
+ num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+ num_tokens += lm_metadata['min_tokens']
+ return num_tokens * constant * 1024 * 1024
+
+def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
+ class ACE15TEModel_(ACE15TEModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
+ super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
+ return ACE15TEModel_
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index 68ac1e804..3afd094d1 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -6,6 +6,7 @@ import math
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
+import comfy.ops
import comfy.ldm.common_dit
import comfy.clip_model
@@ -103,6 +104,75 @@ class Qwen3_06BConfig:
final_norm: bool = True
lm_head: bool = False
+@dataclass
+class Qwen3_06B_ACE15_Config:
+ vocab_size: int = 151669
+ hidden_size: int = 1024
+ intermediate_size: int = 3072
+ num_hidden_layers: int = 28
+ num_attention_heads: int = 16
+ num_key_value_heads: int = 8
+ max_position_embeddings: int = 32768
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 1000000.0
+ transformer_type: str = "llama"
+ head_dim = 128
+ rms_norm_add = False
+ mlp_activation = "silu"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ rope_scale = None
+ final_norm: bool = True
+ lm_head: bool = False
+
+@dataclass
+class Qwen3_2B_ACE15_lm_Config:
+ vocab_size: int = 217204
+ hidden_size: int = 2048
+ intermediate_size: int = 6144
+ num_hidden_layers: int = 28
+ num_attention_heads: int = 16
+ num_key_value_heads: int = 8
+ max_position_embeddings: int = 40960
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 1000000.0
+ transformer_type: str = "llama"
+ head_dim = 128
+ rms_norm_add = False
+ mlp_activation = "silu"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ rope_scale = None
+ final_norm: bool = True
+ lm_head: bool = False
+
+@dataclass
+class Qwen3_4B_ACE15_lm_Config:
+ vocab_size: int = 217204
+ hidden_size: int = 2560
+ intermediate_size: int = 9728
+ num_hidden_layers: int = 36
+ num_attention_heads: int = 32
+ num_key_value_heads: int = 8
+ max_position_embeddings: int = 40960
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 1000000.0
+ transformer_type: str = "llama"
+ head_dim = 128
+ rms_norm_add = False
+ mlp_activation = "silu"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ rope_scale = None
+ final_norm: bool = True
+ lm_head: bool = False
+
@dataclass
class Qwen3_4BConfig:
vocab_size: int = 151936
@@ -581,10 +651,10 @@ class Llama2_(nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
- mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+ mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min)
if seq_len > 1:
- causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+ causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1)
if mask is not None:
mask += causal_mask
else:
@@ -692,6 +762,21 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
+class BaseQwen3:
+ def logits(self, x):
+ input = x[:, -1:]
+ module = self.model.embed_tokens
+
+ offload_stream = None
+ if module.comfy_cast_weights:
+ weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
+ else:
+ weight = self.model.embed_tokens.weight.to(x)
+
+ x = torch.nn.functional.linear(input, weight, None)
+
+ comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
+ return x
class Llama2(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
@@ -720,7 +805,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Qwen3_06B(BaseLlama, torch.nn.Module):
+class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_06BConfig(**config_dict)
@@ -729,7 +814,25 @@ class Qwen3_06B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Qwen3_4B(BaseLlama, torch.nn.Module):
+class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Qwen3_06B_ACE15_Config(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+
+class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Qwen3_2B_ACE15_lm_Config(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+
+class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_4BConfig(**config_dict)
@@ -738,7 +841,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Qwen3_8B(BaseLlama, torch.nn.Module):
+class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Qwen3_4B_ACE15_lm_Config(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+
+class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen3_8BConfig(**config_dict)
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index b0fa14ff6..8542a1dbc 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
-from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
+from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
from . import _io_public as io
from . import _ui_public as ui
from comfy_execution.utils import get_executing_context
@@ -105,6 +105,7 @@ class Types:
VideoComponents = VideoComponents
MESH = MESH
VOXEL = VOXEL
+ File3D = File3D
ComfyAPI = ComfyAPI_latest
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index eeea9781a..93cf482ca 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
-from ._util import MESH, VOXEL, SVG as _SVG
+from ._util import MESH, VOXEL, SVG as _SVG, File3D
class FolderType(str, Enum):
@@ -667,6 +667,49 @@ class Voxel(ComfyTypeIO):
class Mesh(ComfyTypeIO):
Type = MESH
+
+@comfytype(io_type="FILE_3D")
+class File3DAny(ComfyTypeIO):
+ """General 3D file type - accepts any supported 3D format."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_GLB")
+class File3DGLB(ComfyTypeIO):
+ """GLB format 3D file - binary glTF, best for web and cross-platform."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_GLTF")
+class File3DGLTF(ComfyTypeIO):
+ """GLTF format 3D file - JSON-based glTF with external resources."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_FBX")
+class File3DFBX(ComfyTypeIO):
+ """FBX format 3D file - best for game engines and animation."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_OBJ")
+class File3DOBJ(ComfyTypeIO):
+ """OBJ format 3D file - simple geometry format."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_STL")
+class File3DSTL(ComfyTypeIO):
+ """STL format 3D file - best for 3D printing."""
+ Type = File3D
+
+
+@comfytype(io_type="FILE_3D_USDZ")
+class File3DUSDZ(ComfyTypeIO):
+ """USDZ format 3D file - Apple AR format."""
+ Type = File3D
+
+
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):
if TYPE_CHECKING:
@@ -2037,6 +2080,13 @@ __all__ = [
"LossMap",
"Voxel",
"Mesh",
+ "File3DAny",
+ "File3DGLB",
+ "File3DGLTF",
+ "File3DFBX",
+ "File3DOBJ",
+ "File3DSTL",
+ "File3DUSDZ",
"Hooks",
"HookKeyframes",
"TimestepsRange",
diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py
index 6313eb01b..115baf392 100644
--- a/comfy_api/latest/_util/__init__.py
+++ b/comfy_api/latest/_util/__init__.py
@@ -1,5 +1,5 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
-from .geometry_types import VOXEL, MESH
+from .geometry_types import VOXEL, MESH, File3D
from .image_types import SVG
__all__ = [
@@ -9,5 +9,6 @@ __all__ = [
"VideoComponents",
"VOXEL",
"MESH",
+ "File3D",
"SVG",
]
diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py
index 385122778..b586fceb3 100644
--- a/comfy_api/latest/_util/geometry_types.py
+++ b/comfy_api/latest/_util/geometry_types.py
@@ -1,3 +1,8 @@
+import shutil
+from io import BytesIO
+from pathlib import Path
+from typing import IO
+
import torch
@@ -10,3 +15,75 @@ class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
self.vertices = vertices
self.faces = faces
+
+
+class File3D:
+ """Class representing a 3D file from a file path or binary stream.
+
+ Supports both disk-backed (file path) and memory-backed (BytesIO) storage.
+ """
+
+ def __init__(self, source: str | IO[bytes], file_format: str = ""):
+ self._source = source
+ self._format = file_format or self._infer_format()
+
+ def _infer_format(self) -> str:
+ if isinstance(self._source, str):
+ return Path(self._source).suffix.lstrip(".").lower()
+ return ""
+
+ @property
+ def format(self) -> str:
+ return self._format
+
+ @format.setter
+ def format(self, value: str) -> None:
+ self._format = value.lstrip(".").lower() if value else ""
+
+ @property
+ def is_disk_backed(self) -> bool:
+ return isinstance(self._source, str)
+
+ def get_source(self) -> str | IO[bytes]:
+ if isinstance(self._source, str):
+ return self._source
+ if hasattr(self._source, "seek"):
+ self._source.seek(0)
+ return self._source
+
+ def get_data(self) -> BytesIO:
+ if isinstance(self._source, str):
+ with open(self._source, "rb") as f:
+ result = BytesIO(f.read())
+ return result
+ if hasattr(self._source, "seek"):
+ self._source.seek(0)
+ if isinstance(self._source, BytesIO):
+ return self._source
+ return BytesIO(self._source.read())
+
+ def save_to(self, path: str) -> str:
+ dest = Path(path)
+ dest.parent.mkdir(parents=True, exist_ok=True)
+
+ if isinstance(self._source, str):
+ if Path(self._source).resolve() != dest.resolve():
+ shutil.copy2(self._source, dest)
+ else:
+ if hasattr(self._source, "seek"):
+ self._source.seek(0)
+ with open(dest, "wb") as f:
+ f.write(self._source.read())
+ return str(dest)
+
+ def get_bytes(self) -> bytes:
+ if isinstance(self._source, str):
+ return Path(self._source).read_bytes()
+ if hasattr(self._source, "seek"):
+ self._source.seek(0)
+ return self._source.read()
+
+ def __repr__(self) -> str:
+ if isinstance(self._source, str):
+ return f"File3D(source={self._source!r}, format={self._format!r})"
+ return f"File3D(, format={self._format!r})"
diff --git a/comfy_api_nodes/apis/hitpaw.py b/comfy_api_nodes/apis/hitpaw.py
new file mode 100644
index 000000000..b23c5d9eb
--- /dev/null
+++ b/comfy_api_nodes/apis/hitpaw.py
@@ -0,0 +1,51 @@
+from typing import TypedDict
+
+from pydantic import BaseModel, Field
+
+
+class InputVideoModel(TypedDict):
+ model: str
+ resolution: str
+
+
+class ImageEnhanceTaskCreateRequest(BaseModel):
+ model_name: str = Field(...)
+ img_url: str = Field(...)
+ extension: str = Field(".png")
+ exif: bool = Field(False)
+ DPI: int | None = Field(None)
+
+
+class VideoEnhanceTaskCreateRequest(BaseModel):
+ video_url: str = Field(...)
+ extension: str = Field(".mp4")
+    model_name: str = Field(...)
+ resolution: list[int] = Field(..., description="Target resolution [width, height]")
+ original_resolution: list[int] = Field(..., description="Original video resolution [width, height]")
+
+
+class TaskCreateDataResponse(BaseModel):
+ job_id: str = Field(...)
+ consume_coins: int | None = Field(None)
+
+
+class TaskStatusPollRequest(BaseModel):
+ job_id: str = Field(...)
+
+
+class TaskCreateResponse(BaseModel):
+ code: int = Field(...)
+ message: str = Field(...)
+ data: TaskCreateDataResponse | None = Field(None)
+
+
+class TaskStatusDataResponse(BaseModel):
+ job_id: str = Field(...)
+ status: str = Field(...)
+ res_url: str = Field("")
+
+
+class TaskStatusResponse(BaseModel):
+ code: int = Field(...)
+ message: str = Field(...)
+ data: TaskStatusDataResponse = Field(...)
diff --git a/comfy_api_nodes/apis/meshy.py b/comfy_api_nodes/apis/meshy.py
index be46d0d58..7d72e6e91 100644
--- a/comfy_api_nodes/apis/meshy.py
+++ b/comfy_api_nodes/apis/meshy.py
@@ -109,14 +109,19 @@ class MeshyTextureRequest(BaseModel):
class MeshyModelsUrls(BaseModel):
glb: str = Field("")
+ fbx: str = Field("")
+ usdz: str = Field("")
+ obj: str = Field("")
class MeshyRiggedModelsUrls(BaseModel):
rigged_character_glb_url: str = Field("")
+ rigged_character_fbx_url: str = Field("")
class MeshyAnimatedModelsUrls(BaseModel):
animation_glb_url: str = Field("")
+ animation_fbx_url: str = Field("")
class MeshyResultTextureUrls(BaseModel):
diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py
new file mode 100644
index 000000000..488080a74
--- /dev/null
+++ b/comfy_api_nodes/nodes_hitpaw.py
@@ -0,0 +1,342 @@
+import math
+
+from typing_extensions import override
+
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.hitpaw import (
+ ImageEnhanceTaskCreateRequest,
+ InputVideoModel,
+ TaskCreateDataResponse,
+ TaskCreateResponse,
+ TaskStatusPollRequest,
+ TaskStatusResponse,
+ VideoEnhanceTaskCreateRequest,
+)
+from comfy_api_nodes.util import (
+ ApiEndpoint,
+ download_url_to_image_tensor,
+ download_url_to_video_output,
+ downscale_image_tensor,
+ get_image_dimensions,
+ poll_op,
+ sync_op,
+ upload_image_to_comfyapi,
+ upload_video_to_comfyapi,
+ validate_video_duration,
+)
+
+VIDEO_MODELS_MODELS_MAP = {
+ "Portrait Restore Model (1x)": "portrait_restore_1x",
+ "Portrait Restore Model (2x)": "portrait_restore_2x",
+ "General Restore Model (1x)": "general_restore_1x",
+ "General Restore Model (2x)": "general_restore_2x",
+ "General Restore Model (4x)": "general_restore_4x",
+ "Ultra HD Model (2x)": "ultrahd_restore_2x",
+ "Generative Model (1x)": "generative_1x",
+}
+
+# Resolution name to target dimension (shorter side) in pixels
+RESOLUTION_TARGET_MAP = {
+ "720p": 720,
+ "1080p": 1080,
+ "2K/QHD": 1440,
+ "4K/UHD": 2160,
+ "8K": 4320,
+}
+
+# Square (1:1) resolutions use standard square dimensions
+RESOLUTION_SQUARE_MAP = {
+ "720p": 720,
+ "1080p": 1080,
+ "2K/QHD": 1440,
+ "4K/UHD": 2048, # DCI 4K square
+ "8K": 4096, # DCI 8K square
+}
+
+# Models with limited resolution support (no 8K)
+LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"}
+
+# Resolution options for different model types
+RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"]
+RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"]
+
+# Maximum output resolution in pixels
+MAX_PIXELS_GENERATIVE = 32_000_000
+MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000
+
+
+class HitPawGeneralImageEnhance(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="HitPawGeneralImageEnhance",
+ display_name="HitPaw General Image Enhance",
+ category="api node/image/HitPaw",
+ description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
+ f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
+ inputs=[
+ IO.Combo.Input("model", options=["generative_portrait", "generative"]),
+ IO.Image.Input("image"),
+ IO.Combo.Input("upscale_factor", options=[1, 2, 4]),
+ IO.Boolean.Input(
+ "auto_downscale",
+ default=False,
+ tooltip="Automatically downscale input image if output would exceed the limit.",
+ ),
+ ],
+ outputs=[
+ IO.Image.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model"]),
+ expr="""
+ (
+ $prices := {
+ "generative_portrait": {"min": 0.02, "max": 0.06},
+ "generative": {"min": 0.05, "max": 0.15}
+ };
+ $price := $lookup($prices, widgets.model);
+ {
+ "type": "range_usd",
+ "min_usd": $price.min,
+ "max_usd": $price.max
+ }
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model: str,
+ image: Input.Image,
+ upscale_factor: int,
+ auto_downscale: bool,
+ ) -> IO.NodeOutput:
+ height, width = get_image_dimensions(image)
+ requested_scale = upscale_factor
+ output_pixels = height * width * requested_scale * requested_scale
+ if output_pixels > MAX_PIXELS_GENERATIVE:
+ if auto_downscale:
+ input_pixels = width * height
+ scale = 1
+ max_input_pixels = MAX_PIXELS_GENERATIVE
+
+ for candidate in [4, 2, 1]:
+ if candidate > requested_scale:
+ continue
+ scale_output_pixels = input_pixels * candidate * candidate
+ if scale_output_pixels <= MAX_PIXELS_GENERATIVE:
+ scale = candidate
+ max_input_pixels = None
+ break
+ # Check if we can downscale input by at most 2x to fit
+ downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE)
+ if downscale_ratio <= 2.0:
+ scale = candidate
+ max_input_pixels = MAX_PIXELS_GENERATIVE // (candidate * candidate)
+ break
+
+ if max_input_pixels is not None:
+ image = downscale_image_tensor(image, total_pixels=max_input_pixels)
+ upscale_factor = scale
+ else:
+ output_width = width * requested_scale
+ output_height = height * requested_scale
+ raise ValueError(
+ f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) "
+ f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). "
+ f"Enable auto_downscale or use a smaller input image or a lower upscale factor."
+ )
+
+ initial_res = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"),
+ response_model=TaskCreateResponse,
+ data=ImageEnhanceTaskCreateRequest(
+ model_name=f"{model}_{upscale_factor}x",
+ img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None),
+ ),
+ wait_label="Creating task",
+ final_label_on_success="Task created",
+ )
+ if initial_res.code != 200:
+ raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
+        request_price = (initial_res.data.consume_coins or 0) / 1000
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
+            data=TaskStatusPollRequest(job_id=initial_res.data.job_id),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda x: x.data.status,
+ price_extractor=lambda x: request_price,
+ poll_interval=10.0,
+ max_poll_attempts=480,
+ )
+ return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url))
+
+
+class HitPawVideoEnhance(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ model_options = []
+ for model_name in VIDEO_MODELS_MODELS_MAP:
+ if model_name in LIMITED_RESOLUTION_MODELS:
+ resolutions = RESOLUTIONS_LIMITED
+ else:
+ resolutions = RESOLUTIONS_FULL
+ model_options.append(
+ IO.DynamicCombo.Option(
+ model_name,
+ [IO.Combo.Input("resolution", options=resolutions)],
+ )
+ )
+
+ return IO.Schema(
+ node_id="HitPawVideoEnhance",
+ display_name="HitPaw Video Enhance",
+ category="api node/video/HitPaw",
+ description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
+ "Prices shown are per second of video.",
+ inputs=[
+ IO.DynamicCombo.Input("model", options=model_options),
+ IO.Video.Input("video"),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
+ expr="""
+ (
+ $m := $lookup(widgets, "model");
+ $res := $lookup(widgets, "model.resolution");
+ $standard_model_prices := {
+ "original": {"min": 0.01, "max": 0.198},
+ "720p": {"min": 0.01, "max": 0.06},
+ "1080p": {"min": 0.015, "max": 0.09},
+ "2k/qhd": {"min": 0.02, "max": 0.117},
+ "4k/uhd": {"min": 0.025, "max": 0.152},
+ "8k": {"min": 0.033, "max": 0.198}
+ };
+ $ultra_hd_model_prices := {
+ "original": {"min": 0.015, "max": 0.264},
+ "720p": {"min": 0.015, "max": 0.092},
+ "1080p": {"min": 0.02, "max": 0.12},
+ "2k/qhd": {"min": 0.026, "max": 0.156},
+ "4k/uhd": {"min": 0.034, "max": 0.203},
+ "8k": {"min": 0.044, "max": 0.264}
+ };
+ $generative_model_prices := {
+ "original": {"min": 0.015, "max": 0.338},
+ "720p": {"min": 0.008, "max": 0.090},
+ "1080p": {"min": 0.05, "max": 0.15},
+ "2k/qhd": {"min": 0.038, "max": 0.225},
+ "4k/uhd": {"min": 0.056, "max": 0.338}
+ };
+ $prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices :
+ $contains($m, "generative") ? $generative_model_prices :
+ $standard_model_prices;
+ $price := $lookup($prices, $res);
+ {
+ "type": "range_usd",
+ "min_usd": $price.min,
+ "max_usd": $price.max,
+ "format": {"approximate": true, "suffix": "/second"}
+ }
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model: InputVideoModel,
+ video: Input.Video,
+ ) -> IO.NodeOutput:
+ validate_video_duration(video, min_duration=0.5, max_duration=60 * 60)
+ resolution = model["resolution"]
+ src_width, src_height = video.get_dimensions()
+
+ if resolution == "original":
+ output_width = src_width
+ output_height = src_height
+ else:
+ if src_width == src_height:
+ target_size = RESOLUTION_SQUARE_MAP[resolution]
+ if target_size < src_width:
+ raise ValueError(
+ f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than "
+ f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'."
+ )
+ output_width = target_size
+ output_height = target_size
+ else:
+ min_dimension = min(src_width, src_height)
+ target_size = RESOLUTION_TARGET_MAP[resolution]
+ if target_size < min_dimension:
+ raise ValueError(
+ f"Selected resolution {resolution} ({target_size}p) is smaller than "
+ f"the input video's shorter dimension ({min_dimension}p). "
+ f"Please select a higher resolution or 'original'."
+ )
+ if src_width > src_height:
+ output_height = target_size
+ output_width = int(target_size * (src_width / src_height))
+ else:
+ output_width = target_size
+ output_height = int(target_size * (src_height / src_width))
+ initial_res = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"),
+ response_model=TaskCreateResponse,
+ data=VideoEnhanceTaskCreateRequest(
+ video_url=await upload_video_to_comfyapi(cls, video),
+ resolution=[output_width, output_height],
+ original_resolution=[src_width, src_height],
+ model_name=VIDEO_MODELS_MODELS_MAP[model["model"]],
+ ),
+ wait_label="Creating task",
+ final_label_on_success="Task created",
+ )
+        if initial_res.code != 200:
+            raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}")
+        request_price = (initial_res.data.consume_coins or 0) / 1000
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"),
+ data=TaskStatusPollRequest(job_id=initial_res.data.job_id),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda x: x.data.status,
+ price_extractor=lambda x: request_price,
+ poll_interval=10.0,
+ max_poll_attempts=320,
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url))
+
+
+class HitPawExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ HitPawGeneralImageEnhance,
+ HitPawVideoEnhance,
+ ]
+
+
+async def comfy_entrypoint() -> HitPawExtension:
+ return HitPawExtension()
diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py
index b3a736643..813a7c809 100644
--- a/comfy_api_nodes/nodes_hunyuan3d.py
+++ b/comfy_api_nodes/nodes_hunyuan3d.py
@@ -1,5 +1,3 @@
-import os
-
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
@@ -14,7 +12,7 @@ from comfy_api_nodes.apis.hunyuan3d import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
- download_url_to_bytesio,
+ download_url_to_file_3d,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
@@ -22,14 +20,13 @@ from comfy_api_nodes.util import (
validate_image_dimensions,
validate_string,
)
-from folder_paths import get_output_directory
-def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D:
+def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
for i in response_objs:
- if i.Type.lower() == "glb":
+ if i.Type.lower() == file_type.lower():
return i
- raise ValueError("No GLB file found in response. Please report this to the developers.")
+ return None
class TencentTextToModelNode(IO.ComfyNode):
@@ -74,7 +71,9 @@ class TencentTextToModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DOBJ.Output(display_name="OBJ"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -124,19 +123,20 @@ class TencentTextToModelNode(IO.ComfyNode):
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+ task_id = response.JobId
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
- data=To3DProTaskQueryRequest(JobId=response.JobId),
+ data=To3DProTaskQueryRequest(JobId=task_id),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
- model_file = f"hunyuan_model_{response.JobId}.glb"
- await download_url_to_bytesio(
- get_glb_obj_from_response(result.ResultFile3Ds).Url,
- os.path.join(get_output_directory(), model_file),
+ glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
+ obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
+ file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
+        return IO.NodeOutput(
+            f"hunyuan_model_{task_id}.glb", file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
)
- return IO.NodeOutput(model_file)
class TencentImageToModelNode(IO.ComfyNode):
@@ -184,7 +184,9 @@ class TencentImageToModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DOBJ.Output(display_name="OBJ"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -269,19 +271,20 @@ class TencentImageToModelNode(IO.ComfyNode):
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+ task_id = response.JobId
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"),
- data=To3DProTaskQueryRequest(JobId=response.JobId),
+ data=To3DProTaskQueryRequest(JobId=task_id),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
- model_file = f"hunyuan_model_{response.JobId}.glb"
- await download_url_to_bytesio(
- get_glb_obj_from_response(result.ResultFile3Ds).Url,
- os.path.join(get_output_directory(), model_file),
+ glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
+ obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
+ file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
+        return IO.NodeOutput(
+            f"hunyuan_model_{task_id}.glb", file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
)
- return IO.NodeOutput(model_file)
class TencentHunyuan3DExtension(ComfyExtension):
diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py
index 740607983..65f6f0d2d 100644
--- a/comfy_api_nodes/nodes_meshy.py
+++ b/comfy_api_nodes/nodes_meshy.py
@@ -1,5 +1,3 @@
-import os
-
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
@@ -20,13 +18,12 @@ from comfy_api_nodes.apis.meshy import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
- download_url_to_bytesio,
+ download_url_to_file_3d,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_string,
)
-from folder_paths import get_output_directory
class MeshyTextToModelNode(IO.ComfyNode):
@@ -79,8 +76,10 @@ class MeshyTextToModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -122,16 +121,20 @@ class MeshyTextToModelNode(IO.ComfyNode):
seed=seed,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"),
response_model=MeshyModelResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
+ )
class MeshyRefineNode(IO.ComfyNode):
@@ -167,8 +170,10 @@ class MeshyRefineNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -210,16 +215,20 @@ class MeshyRefineNode(IO.ComfyNode):
ai_model=model,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"),
response_model=MeshyModelResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
+ )
class MeshyImageToModelNode(IO.ComfyNode):
@@ -303,8 +312,10 @@ class MeshyImageToModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -368,16 +379,20 @@ class MeshyImageToModelNode(IO.ComfyNode):
seed=seed,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{task_id}"),
response_model=MeshyModelResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
+ )
class MeshyMultiImageToModelNode(IO.ComfyNode):
@@ -464,8 +479,10 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -531,16 +548,20 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
seed=seed,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{task_id}"),
response_model=MeshyModelResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
+ )
class MeshyRigModelNode(IO.ComfyNode):
@@ -571,8 +592,10 @@ class MeshyRigModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -606,18 +629,20 @@ class MeshyRigModelNode(IO.ComfyNode):
texture_image_url=texture_image_url,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{task_id}"),
response_model=MeshyRiggedResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(
- result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.result.rigged_character_glb_url, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.result.rigged_character_fbx_url, "fbx", task_id=task_id),
)
- return IO.NodeOutput(model_file, response.result)
class MeshyAnimateModelNode(IO.ComfyNode):
@@ -640,7 +665,9 @@ class MeshyAnimateModelNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -669,16 +696,19 @@ class MeshyAnimateModelNode(IO.ComfyNode):
action_id=action_id,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{task_id}"),
response_model=MeshyAnimationResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ await download_url_to_file_3d(result.result.animation_glb_url, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.result.animation_fbx_url, "fbx", task_id=task_id),
+ )
class MeshyTextureNode(IO.ComfyNode):
@@ -715,8 +745,10 @@ class MeshyTextureNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -760,16 +792,20 @@ class MeshyTextureNode(IO.ComfyNode):
image_style_url=image_style_url,
),
)
+ task_id = response.result
result = await poll_op(
cls,
- ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"),
+ ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{task_id}"),
response_model=MeshyModelResult,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
- model_file = f"meshy_model_{response.result}.glb"
- await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file))
- return IO.NodeOutput(model_file, response.result)
+ return IO.NodeOutput(
+ f"{task_id}.glb",
+ task_id,
+ await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id),
+ await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id),
+ )
class MeshyExtension(ComfyExtension):
diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py
index 3ffdc8b90..f9cff121f 100644
--- a/comfy_api_nodes/nodes_rodin.py
+++ b/comfy_api_nodes/nodes_rodin.py
@@ -10,7 +10,6 @@ import folder_paths as comfy_paths
import os
import logging
import math
-from typing import Optional
from io import BytesIO
from typing_extensions import override
from PIL import Image
@@ -28,8 +27,9 @@ from comfy_api_nodes.util import (
poll_op,
ApiEndpoint,
download_url_to_bytesio,
+ download_url_to_file_3d,
)
-from comfy_api.latest import ComfyExtension, IO
+from comfy_api.latest import ComfyExtension, IO, Types
COMMON_PARAMETERS = [
@@ -177,7 +177,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
return "DONE"
return "Generating"
-def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
+def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None:
if not response.jobs:
return None
completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
@@ -207,17 +207,25 @@ async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3D
)
-async def download_files(url_list, task_uuid: str):
+async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.File3D | None]:
result_folder_name = f"Rodin3D_{task_uuid}"
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
os.makedirs(save_path, exist_ok=True)
model_file_path = None
+ file_3d = None
+
for i in url_list.list:
file_path = os.path.join(save_path, i.name)
- if file_path.endswith(".glb"):
+ if i.name.lower().endswith(".glb"):
model_file_path = os.path.join(result_folder_name, i.name)
- await download_url_to_bytesio(i.url, file_path)
- return model_file_path
+ file_3d = await download_url_to_file_3d(i.url, "glb")
+ # Save to disk for backward compatibility
+ with open(file_path, "wb") as f:
+ f.write(file_3d.get_bytes())
+ else:
+ await download_url_to_bytesio(i.url, file_path)
+
+ return model_file_path, file_3d
class Rodin3D_Regular(IO.ComfyNode):
@@ -234,7 +242,10 @@ class Rodin3D_Regular(IO.ComfyNode):
IO.Image.Input("Images"),
*COMMON_PARAMETERS,
],
- outputs=[IO.String.Output(display_name="3D Model Path")],
+ outputs=[
+ IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ ],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
@@ -271,9 +282,9 @@ class Rodin3D_Regular(IO.ComfyNode):
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
- model = await download_files(download_list, task_uuid)
+ model_path, file_3d = await download_files(download_list, task_uuid)
- return IO.NodeOutput(model)
+ return IO.NodeOutput(model_path, file_3d)
class Rodin3D_Detail(IO.ComfyNode):
@@ -290,7 +301,10 @@ class Rodin3D_Detail(IO.ComfyNode):
IO.Image.Input("Images"),
*COMMON_PARAMETERS,
],
- outputs=[IO.String.Output(display_name="3D Model Path")],
+ outputs=[
+ IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ ],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
@@ -327,9 +341,9 @@ class Rodin3D_Detail(IO.ComfyNode):
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
- model = await download_files(download_list, task_uuid)
+ model_path, file_3d = await download_files(download_list, task_uuid)
- return IO.NodeOutput(model)
+ return IO.NodeOutput(model_path, file_3d)
class Rodin3D_Smooth(IO.ComfyNode):
@@ -346,7 +360,10 @@ class Rodin3D_Smooth(IO.ComfyNode):
IO.Image.Input("Images"),
*COMMON_PARAMETERS,
],
- outputs=[IO.String.Output(display_name="3D Model Path")],
+ outputs=[
+ IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ ],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
@@ -382,9 +399,9 @@ class Rodin3D_Smooth(IO.ComfyNode):
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
- model = await download_files(download_list, task_uuid)
+ model_path, file_3d = await download_files(download_list, task_uuid)
- return IO.NodeOutput(model)
+ return IO.NodeOutput(model_path, file_3d)
class Rodin3D_Sketch(IO.ComfyNode):
@@ -408,7 +425,10 @@ class Rodin3D_Sketch(IO.ComfyNode):
optional=True,
),
],
- outputs=[IO.String.Output(display_name="3D Model Path")],
+ outputs=[
+ IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ ],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
@@ -441,9 +461,9 @@ class Rodin3D_Sketch(IO.ComfyNode):
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
- model = await download_files(download_list, task_uuid)
+ model_path, file_3d = await download_files(download_list, task_uuid)
- return IO.NodeOutput(model)
+ return IO.NodeOutput(model_path, file_3d)
class Rodin3D_Gen2(IO.ComfyNode):
@@ -475,7 +495,10 @@ class Rodin3D_Gen2(IO.ComfyNode):
),
IO.Boolean.Input("TAPose", default=False),
],
- outputs=[IO.String.Output(display_name="3D Model Path")],
+ outputs=[
+ IO.String.Output(display_name="3D Model Path"), # for backward compatibility only
+ IO.File3DGLB.Output(display_name="GLB"),
+ ],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
@@ -511,9 +534,9 @@ class Rodin3D_Gen2(IO.ComfyNode):
)
await poll_for_task_status(subscription_key, cls)
download_list = await get_rodin_download_list(task_uuid, cls)
- model = await download_files(download_list, task_uuid)
+ model_path, file_3d = await download_files(download_list, task_uuid)
- return IO.NodeOutput(model)
+ return IO.NodeOutput(model_path, file_3d)
class Rodin3DExtension(ComfyExtension):
diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py
index 5abf27b4d..67c7f59fc 100644
--- a/comfy_api_nodes/nodes_tripo.py
+++ b/comfy_api_nodes/nodes_tripo.py
@@ -1,10 +1,6 @@
-import os
-from typing import Optional
-
-import torch
from typing_extensions import override
-from comfy_api.latest import IO, ComfyExtension
+from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.tripo import (
TripoAnimateRetargetRequest,
TripoAnimateRigRequest,
@@ -26,12 +22,11 @@ from comfy_api_nodes.apis.tripo import (
)
from comfy_api_nodes.util import (
ApiEndpoint,
- download_url_as_bytesio,
+ download_url_to_file_3d,
poll_op,
sync_op,
upload_images_to_comfyapi,
)
-from folder_paths import get_output_directory
def get_model_url_from_response(response: TripoTaskResponse) -> str:
@@ -45,7 +40,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str:
async def poll_until_finished(
node_cls: type[IO.ComfyNode],
response: TripoTaskResponse,
- average_duration: Optional[int] = None,
+ average_duration: int | None = None,
) -> IO.NodeOutput:
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
if response.code != 0:
@@ -69,12 +64,8 @@ async def poll_until_finished(
)
if response_poll.data.status == TripoTaskStatus.SUCCESS:
url = get_model_url_from_response(response_poll)
- bytesio = await download_url_as_bytesio(url)
- # Save the downloaded model file
- model_file = f"tripo_model_{task_id}.glb"
- with open(os.path.join(get_output_directory(), model_file), "wb") as f:
- f.write(bytesio.getvalue())
- return IO.NodeOutput(model_file, task_id)
+ file_glb = await download_url_to_file_3d(url, "glb", task_id=task_id)
+ return IO.NodeOutput(f"{task_id}.glb", task_id, file_glb)
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
@@ -107,8 +98,9 @@ class TripoTextToModelNode(IO.ComfyNode):
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -155,18 +147,18 @@ class TripoTextToModelNode(IO.ComfyNode):
async def execute(
cls,
prompt: str,
- negative_prompt: Optional[str] = None,
+ negative_prompt: str | None = None,
model_version=None,
- style: Optional[str] = None,
- texture: Optional[bool] = None,
- pbr: Optional[bool] = None,
- image_seed: Optional[int] = None,
- model_seed: Optional[int] = None,
- texture_seed: Optional[int] = None,
- texture_quality: Optional[str] = None,
- geometry_quality: Optional[str] = None,
- face_limit: Optional[int] = None,
- quad: Optional[bool] = None,
+ style: str | None = None,
+ texture: bool | None = None,
+ pbr: bool | None = None,
+ image_seed: int | None = None,
+ model_seed: int | None = None,
+ texture_seed: int | None = None,
+ texture_quality: str | None = None,
+ geometry_quality: str | None = None,
+ face_limit: int | None = None,
+ quad: bool | None = None,
) -> IO.NodeOutput:
style_enum = None if style == "None" else style
if not prompt:
@@ -232,8 +224,9 @@ class TripoImageToModelNode(IO.ComfyNode):
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -279,19 +272,19 @@ class TripoImageToModelNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- image: torch.Tensor,
- model_version: Optional[str] = None,
- style: Optional[str] = None,
- texture: Optional[bool] = None,
- pbr: Optional[bool] = None,
- model_seed: Optional[int] = None,
+ image: Input.Image,
+ model_version: str | None = None,
+ style: str | None = None,
+ texture: bool | None = None,
+ pbr: bool | None = None,
+ model_seed: int | None = None,
orientation=None,
- texture_seed: Optional[int] = None,
- texture_quality: Optional[str] = None,
- geometry_quality: Optional[str] = None,
- texture_alignment: Optional[str] = None,
- face_limit: Optional[int] = None,
- quad: Optional[bool] = None,
+ texture_seed: int | None = None,
+ texture_quality: str | None = None,
+ geometry_quality: str | None = None,
+ texture_alignment: str | None = None,
+ face_limit: int | None = None,
+ quad: bool | None = None,
) -> IO.NodeOutput:
style_enum = None if style == "None" else style
if image is None:
@@ -368,8 +361,9 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -411,21 +405,21 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- image: torch.Tensor,
- image_left: Optional[torch.Tensor] = None,
- image_back: Optional[torch.Tensor] = None,
- image_right: Optional[torch.Tensor] = None,
- model_version: Optional[str] = None,
- orientation: Optional[str] = None,
- texture: Optional[bool] = None,
- pbr: Optional[bool] = None,
- model_seed: Optional[int] = None,
- texture_seed: Optional[int] = None,
- texture_quality: Optional[str] = None,
- geometry_quality: Optional[str] = None,
- texture_alignment: Optional[str] = None,
- face_limit: Optional[int] = None,
- quad: Optional[bool] = None,
+ image: Input.Image,
+ image_left: Input.Image | None = None,
+ image_back: Input.Image | None = None,
+ image_right: Input.Image | None = None,
+ model_version: str | None = None,
+ orientation: str | None = None,
+ texture: bool | None = None,
+ pbr: bool | None = None,
+ model_seed: int | None = None,
+ texture_seed: int | None = None,
+ texture_quality: str | None = None,
+ geometry_quality: str | None = None,
+ texture_alignment: str | None = None,
+ face_limit: int | None = None,
+ quad: bool | None = None,
) -> IO.NodeOutput:
if image is None:
raise RuntimeError("front image for multiview is required")
@@ -487,8 +481,9 @@ class TripoTextureNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -512,11 +507,11 @@ class TripoTextureNode(IO.ComfyNode):
async def execute(
cls,
model_task_id,
- texture: Optional[bool] = None,
- pbr: Optional[bool] = None,
- texture_seed: Optional[int] = None,
- texture_quality: Optional[str] = None,
- texture_alignment: Optional[str] = None,
+ texture: bool | None = None,
+ pbr: bool | None = None,
+ texture_seed: int | None = None,
+ texture_quality: str | None = None,
+ texture_alignment: str | None = None,
) -> IO.NodeOutput:
response = await sync_op(
cls,
@@ -547,8 +542,9 @@ class TripoRefineNode(IO.ComfyNode):
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -583,8 +579,9 @@ class TripoRigNode(IO.ComfyNode):
category="api node/3d/Tripo",
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@@ -642,8 +639,9 @@ class TripoRetargetNode(IO.ComfyNode):
),
],
outputs=[
- IO.String.Output(display_name="model_file"),
+ IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"),
+ IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
index c3c9ff4bf..18b020eef 100644
--- a/comfy_api_nodes/util/__init__.py
+++ b/comfy_api_nodes/util/__init__.py
@@ -28,6 +28,7 @@ from .conversions import (
from .download_helpers import (
download_url_as_bytesio,
download_url_to_bytesio,
+ download_url_to_file_3d,
download_url_to_image_tensor,
download_url_to_video_output,
)
@@ -69,6 +70,7 @@ __all__ = [
# Download helpers
"download_url_as_bytesio",
"download_url_to_bytesio",
+ "download_url_to_file_3d",
"download_url_to_image_tensor",
"download_url_to_video_output",
# Conversions
diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py
index 4668d14a9..78bcf1fa1 100644
--- a/comfy_api_nodes/util/download_helpers.py
+++ b/comfy_api_nodes/util/download_helpers.py
@@ -11,7 +11,8 @@ import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError
from comfy_api.latest import IO as COMFY_IO
-from comfy_api.latest import InputImpl
+from comfy_api.latest import InputImpl, Types
+from folder_paths import get_output_directory
from . import request_logger
from ._helpers import (
@@ -261,3 +262,38 @@ def _generate_operation_id(method: str, url: str, attempt: int) -> str:
except Exception:
slug = "download"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
+
+
+async def download_url_to_file_3d(
+ url: str,
+ file_format: str,
+ *,
+ task_id: str | None = None,
+ timeout: float | None = None,
+ max_retries: int = 5,
+ cls: type[COMFY_IO.ComfyNode] = None,
+) -> Types.File3D:
+ """Downloads a 3D model file from a URL into memory and returns it as a File3D.
+
+ If task_id is provided, also writes the file to disk in the output directory
+ for backward compatibility with the old save-to-disk behavior.
+ """
+ file_format = file_format.lstrip(".").lower()
+ data = BytesIO()
+ await download_url_to_bytesio(
+ url,
+ data,
+ timeout=timeout,
+ max_retries=max_retries,
+ cls=cls,
+ )
+
+ if task_id is not None:
+ # This is only for backward compatibility with current behavior when every 3D node is output node
+ # All new API nodes should not use "task_id" and instead users should use "SaveGLB" node to save results
+ output_dir = Path(get_output_directory())
+ output_path = output_dir / f"{task_id}.{file_format}"
+ output_path.write_bytes(data.getvalue())
+ data.seek(0)
+
+ return Types.File3D(source=data, file_format=file_format)
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index 3153f2b98..83d936ce1 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -94,7 +94,7 @@ async def upload_image_to_comfyapi(
*,
mime_type: str | None = None,
wait_label: str | None = "Uploading",
- total_pixels: int = 2048 * 2048,
+ total_pixels: int | None = 2048 * 2048,
) -> str:
"""Uploads a single image to ComfyUI API and returns its download URL."""
return (
diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py
index 1409233c9..376584e5c 100644
--- a/comfy_extras/nodes_ace.py
+++ b/comfy_extras/nodes_ace.py
@@ -28,12 +28,39 @@ class TextEncodeAceStepAudio(io.ComfyNode):
conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
return io.NodeOutput(conditioning)
+class TextEncodeAceStepAudio15(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="TextEncodeAceStepAudio1.5",
+ category="conditioning",
+ inputs=[
+ io.Clip.Input("clip"),
+ io.String.Input("tags", multiline=True, dynamic_prompts=True),
+ io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
+ io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
+ io.Int.Input("bpm", default=120, min=10, max=300),
+ io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
+ io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
+ io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
+ io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
+ ],
+ outputs=[io.Conditioning.Output()],
+ )
+
+ @classmethod
+ def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput:
+ tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed)
+ conditioning = clip.encode_from_tokens_scheduled(tokens)
+ return io.NodeOutput(conditioning)
+
class EmptyAceStepLatentAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyAceStepLatentAudio",
+ display_name="Empty Ace Step 1.0 Latent Audio",
category="latent/audio",
inputs=[
io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
@@ -51,12 +78,60 @@ class EmptyAceStepLatentAudio(io.ComfyNode):
return io.NodeOutput({"samples": latent, "type": "audio"})
+class EmptyAceStep15LatentAudio(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="EmptyAceStep1.5LatentAudio",
+ display_name="Empty Ace Step 1.5 Latent Audio",
+ category="latent/audio",
+ inputs=[
+ io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
+ io.Int.Input(
+ "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
+ ),
+ ],
+ outputs=[io.Latent.Output()],
+ )
+
+ @classmethod
+ def execute(cls, seconds, batch_size) -> io.NodeOutput:
+ length = round((seconds * 48000 / 1920))
+ latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
+ return io.NodeOutput({"samples": latent, "type": "audio"})
+
+class ReferenceTimbreAudio(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="ReferenceTimbreAudio",
+ category="advanced/conditioning/audio",
+ is_experimental=True,
+ description="This node sets the reference audio for timbre (for ace step 1.5)",
+ inputs=[
+ io.Conditioning.Input("conditioning"),
+ io.Latent.Input("latent", optional=True),
+ ],
+ outputs=[
+ io.Conditioning.Output(),
+ ]
+ )
+
+ @classmethod
+ def execute(cls, conditioning, latent=None) -> io.NodeOutput:
+ if latent is not None:
+ conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True)
+ return io.NodeOutput(conditioning)
+
class AceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeAceStepAudio,
EmptyAceStepLatentAudio,
+ TextEncodeAceStepAudio15,
+ EmptyAceStep15LatentAudio,
+ ReferenceTimbreAudio,
]
async def comfy_entrypoint() -> AceExtension:
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 271b75fbd..bef723dce 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -82,13 +82,14 @@ class VAEEncodeAudio(IO.ComfyNode):
@classmethod
def execute(cls, vae, audio) -> IO.NodeOutput:
sample_rate = audio["sample_rate"]
- if 44100 != sample_rate:
- waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
+ vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
+ if vae_sample_rate != sample_rate:
+ waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate)
else:
waveform = audio["waveform"]
t = vae.encode(waveform.movedim(1, -1))
- return IO.NodeOutput({"samples":t})
+ return IO.NodeOutput({"samples": t})
encode = execute # TODO: remove
@@ -114,7 +115,8 @@ class VAEDecodeAudio(IO.ComfyNode):
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
- return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]})
+ vae_sample_rate = getattr(vae, "audio_sample_rate", 44100)
+ return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]})
decode = execute # TODO: remove
diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py
index 5bb5df48e..eda1639ab 100644
--- a/comfy_extras/nodes_hunyuan3d.py
+++ b/comfy_extras/nodes_hunyuan3d.py
@@ -622,14 +622,20 @@ class SaveGLB(IO.ComfyNode):
category="3d",
is_output_node=True,
inputs=[
- IO.Mesh.Input("mesh"),
+ IO.MultiType.Input(
+ IO.Mesh.Input("mesh"),
+ types=[
+ IO.File3DGLB,
+ ],
+ tooltip="Mesh or GLB file to save",
+ ),
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
)
@classmethod
- def execute(cls, mesh, filename_prefix) -> IO.NodeOutput:
+ def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
results = []
@@ -641,15 +647,26 @@ class SaveGLB(IO.ComfyNode):
for x in cls.hidden.extra_pnginfo:
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
- for i in range(mesh.vertices.shape[0]):
+ if isinstance(mesh, Types.File3D):
+ # Handle File3D input - save BytesIO data to output folder
f = f"{filename}_{counter:05}_.glb"
- save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
+ mesh.save_to(os.path.join(full_output_folder, f))
results.append({
"filename": f,
"subfolder": subfolder,
"type": "output"
})
- counter += 1
+ else:
+ # Handle Mesh input - save vertices and faces as GLB
+ for i in range(mesh.vertices.shape[0]):
+ f = f"{filename}_{counter:05}_.glb"
+ save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
+ results.append({
+ "filename": f,
+ "subfolder": subfolder,
+ "type": "output"
+ })
+ counter += 1
return IO.NodeOutput(ui={"3d": results})
diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py
index 4b8d950ae..f29510488 100644
--- a/comfy_extras/nodes_load_3d.py
+++ b/comfy_extras/nodes_load_3d.py
@@ -1,9 +1,10 @@
import nodes
import folder_paths
import os
+import uuid
from typing_extensions import override
-from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
+from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types
from pathlib import Path
@@ -81,7 +82,19 @@ class Preview3D(IO.ComfyNode):
is_experimental=True,
is_output_node=True,
inputs=[
- IO.String.Input("model_file", default="", multiline=False),
+ IO.MultiType.Input(
+ IO.String.Input("model_file", default="", multiline=False),
+ types=[
+ IO.File3DGLB,
+ IO.File3DGLTF,
+ IO.File3DFBX,
+ IO.File3DOBJ,
+ IO.File3DSTL,
+ IO.File3DUSDZ,
+ IO.File3DAny,
+ ],
+ tooltip="3D model file or path string",
+ ),
IO.Load3DCamera.Input("camera_info", optional=True),
IO.Image.Input("bg_image", optional=True),
],
@@ -89,10 +102,15 @@ class Preview3D(IO.ComfyNode):
)
@classmethod
- def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
+ def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput:
+ if isinstance(model_file, Types.File3D):
+ filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}"
+ model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
+ else:
+ filename = model_file
camera_info = kwargs.get("camera_info", None)
bg_image = kwargs.get("bg_image", None)
- return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
+ return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image))
process = execute # TODO: remove
diff --git a/comfyui_version.py b/comfyui_version.py
index b1ebaa115..2e2c12ced 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.11.1"
+__version__ = "0.12.1"
diff --git a/nodes.py b/nodes.py
index 1cb43d9e2..e11a8ed80 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1001,7 +1001,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
- "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ),
+ "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
diff --git a/pyproject.toml b/pyproject.toml
index 042f124e4..c21ee03f1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.11.1"
+version = "0.12.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
diff --git a/requirements.txt b/requirements.txt
index 3ca417dd8..0c401873a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.37.11
-comfyui-workflow-templates==0.8.27
+comfyui-workflow-templates==0.8.31
comfyui-embedded-docs==0.4.0
torch
torchsde