From aa6f7a83bb313f121ee032a8b21aa72ea32cfee1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 31 Jan 2026 17:05:11 -0800 Subject: [PATCH 01/32] Send is_input_list on v1 and v3 schema to frontend (#12188) --- comfy_api/latest/_io.py | 2 ++ server.py | 1 + 2 files changed, 3 insertions(+) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 78f77d4b2..eeea9781a 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1248,6 +1248,7 @@ class Hidden(str, Enum): class NodeInfoV1: input: dict=None input_order: dict[str, list[str]]=None + is_input_list: bool=None output: list[str]=None output_is_list: list[bool]=None output_name: list[str]=None @@ -1474,6 +1475,7 @@ class Schema: info = NodeInfoV1( input=input, input_order={key: list(value.keys()) for (key, value) in input.items()}, + is_input_list=self.is_input_list, output=output, output_is_list=output_is_list, output_name=output_name, diff --git a/server.py b/server.py index 2aee5cc06..2300393b2 100644 --- a/server.py +++ b/server.py @@ -656,6 +656,7 @@ class PromptServer(): info = {} info['input'] = obj_class.INPUT_TYPES() info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} + info['is_input_list'] = getattr(obj_class, "INPUT_IS_LIST", False) info['output'] = obj_class.RETURN_TYPES info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] From 873de5f37a1f4c39c0203b2989e6736b0cb12b98 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 31 Jan 2026 18:11:11 -0800 Subject: [PATCH 02/32] KV cache implementation for using llama models for text generation. (#12195) --- comfy/text_encoders/llama.py | 91 +++++++++++++++++++++++++++++------- 1 file changed, 74 insertions(+), 17 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 3080a3e09..68ac1e804 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from dataclasses import dataclass -from typing import Optional, Any +from typing import Optional, Any, Tuple import math from comfy.ldm.modules.attention import optimized_attention_for_device @@ -32,6 +32,7 @@ class Llama2Config: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Mistral3Small24BConfig: @@ -54,6 +55,7 @@ class Mistral3Small24BConfig: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen25_3BConfig: @@ -76,6 +78,7 @@ class Qwen25_3BConfig: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen3_06BConfig: @@ -98,6 +101,7 @@ class Qwen3_06BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen3_4BConfig: @@ -120,6 +124,7 @@ class Qwen3_4BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen3_8BConfig: @@ -142,6 +147,7 @@ class Qwen3_8BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Ovis25_2BConfig: @@ -164,6 +170,7 @@ class Ovis25_2BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen25_7BVLI_Config: @@ -186,6 +193,7 @@ class Qwen25_7BVLI_Config: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Gemma2_2B_Config: @@ -209,6 +217,7 @@ class Gemma2_2B_Config: sliding_attention = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Gemma3_4B_Config: @@ -232,6 +241,7 @@ class Gemma3_4B_Config: sliding_attention = [1024, 1024, 1024, 1024, 1024, False] rope_scale = [8.0, 1.0] final_norm: bool = True + lm_head: bool = False @dataclass class Gemma3_12B_Config: @@ -255,6 +265,7 @@ class Gemma3_12B_Config: sliding_attention = [1024, 1024, 1024, 1024, 1024, False] rope_scale = [8.0, 1.0] final_norm: bool = True + lm_head: bool = False vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} mm_tokens_per_image = 256 @@ -356,6 +367,7 @@ class Attention(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): batch_size, seq_length, _ = hidden_states.shape xq = self.q_proj(hidden_states) @@ -373,11 +385,30 @@ class Attention(nn.Module): xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis) + present_key_value = None + if past_key_value is not None: + index = 0 + num_tokens = xk.shape[2] + if len(past_key_value) > 0: + past_key, past_value, index = past_key_value + if past_key.shape[2] >= (index + num_tokens): + past_key[:, :, index:index + xk.shape[2]] = xk + past_value[:, :, index:index + xv.shape[2]] = xv + xk = past_key[:, :, :index + xk.shape[2]] + xv = past_value[:, :, :index + xv.shape[2]] + present_key_value = (past_key, past_value, index + num_tokens) + else: + xk = torch.cat((past_key[:, :, :index], xk), dim=2) + xv = torch.cat((past_value[:, :, :index], xv), dim=2) + present_key_value = (xk, xv, index + num_tokens) + else: + present_key_value = (xk, xv, index + num_tokens) + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True) - return self.o_proj(output) + return self.o_proj(output), present_key_value class MLP(nn.Module): def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): @@ -408,15 +439,17 @@ class TransformerBlock(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): # Self Attention residual = x x = self.input_layernorm(x) - x = self.self_attn( + x, present_key_value = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_key_value, ) x = residual + x @@ -426,7 +459,7 @@ class TransformerBlock(nn.Module): x = self.mlp(x) x = residual + x - return x + return x, present_key_value class TransformerBlockGemma2(nn.Module): def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None): @@ -451,6 +484,7 @@ class TransformerBlockGemma2(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): if self.transformer_type == 'gemma3': if self.sliding_attention: @@ -468,11 +502,12 @@ class TransformerBlockGemma2(nn.Module): # Self Attention residual = x x = self.input_layernorm(x) - x = self.self_attn( + x, present_key_value = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_key_value, ) x = self.post_attention_layernorm(x) @@ -485,7 +520,7 @@ class TransformerBlockGemma2(nn.Module): x = self.post_feedforward_layernorm(x) x = residual + x - return x + return x, present_key_value class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): @@ -516,9 +551,10 @@ class Llama2_(nn.Module): else: self.norm = None - # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) + if config.lm_head: + self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): if embeds is not None: x = embeds else: @@ -527,8 +563,13 @@ class Llama2_(nn.Module): if self.normalize_in: x *= self.config.hidden_size ** 0.5 + seq_len = x.shape[1] + past_len = 0 + if past_key_values is not None and len(past_key_values) > 0: + past_len = past_key_values[0][2] + if position_ids is None: - position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0) + position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0) freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, @@ -539,14 +580,16 @@ class Llama2_(nn.Module): mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) - if mask is not None: - mask += causal_mask - else: - mask = causal_mask + if seq_len > 1: + causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if mask is not None: + mask += causal_mask + else: + mask = causal_mask + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) intermediate = None @@ -562,16 +605,27 @@ class Llama2_(nn.Module): elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output + next_key_values = [] for i, layer in enumerate(self.layers): if all_intermediate is not None: if only_layers is None or (i in only_layers): all_intermediate.append(x.unsqueeze(1).clone()) - x = layer( + + past_kv = None + if past_key_values is not None: + past_kv = past_key_values[i] if len(past_key_values) > 0 else [] + + x, current_kv = layer( x=x, attention_mask=mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_kv, ) + + if current_kv is not None: + next_key_values.append(current_kv) + if i == intermediate_output: intermediate = x.clone() @@ -588,7 +642,10 @@ class Llama2_(nn.Module): if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: intermediate = self.norm(intermediate) - return x, intermediate + if len(next_key_values) > 0: + return x, intermediate, next_key_values + else: + return x, intermediate class Gemma3MultiModalProjector(torch.nn.Module): From f8acd9c402f41efc2c748503e4c5d9e0027965ab Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 31 Jan 2026 22:01:11 -0800 Subject: [PATCH 03/32] Reduce RAM usage, fix VRAM OOMs, and fix Windows shared memory spilling with adaptive model loading (#11845) --- comfy/audio_encoders/audio_encoders.py | 4 +- comfy/cli_args.py | 4 + comfy/clip_vision.py | 4 +- comfy/controlnet.py | 2 +- comfy/k_diffusion/sampling.py | 33 ++- comfy/ldm/hunyuan_video/upsampler.py | 4 +- comfy/memory_management.py | 81 +++++++ comfy/model_base.py | 15 +- comfy/model_management.py | 170 ++++++++++++-- comfy/model_patcher.py | 313 +++++++++++++++++++++++-- comfy/ops.py | 221 +++++++++++++++-- comfy/pinned_memory.py | 30 +++ comfy/samplers.py | 3 +- comfy/sd.py | 59 +++-- comfy/sd1_clip.py | 2 +- comfy/text_encoders/lt.py | 2 +- comfy/utils.py | 80 ++++++- comfy/windows.py | 52 ++++ comfy_extras/nodes_model_patch.py | 6 +- cuda_malloc.py | 12 +- execution.py | 16 +- main.py | 30 ++- requirements.txt | 1 + 23 files changed, 1030 insertions(+), 114 deletions(-) create mode 100644 comfy/memory_management.py create mode 100644 comfy/pinned_memory.py create mode 100644 comfy/windows.py diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 46ef21c95..16998af94 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -25,11 +25,11 @@ class AudioEncoderModel(): elif model_type == "whisper3": self.model = WhisperLargeV3(**model_config) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 1716c3de7..63daca861 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" + DynamicVRAM = "dynamic_vram" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) @@ -257,3 +258,6 @@ elif args.fast == []: # '--fast' is provided with a list of performance features, use that list else: args.fast = set(args.fast) + +def enables_dynamic_vram(): + return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index b28bf636c..1691fca81 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -47,10 +47,10 @@ class ClipVisionModel(): self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b5e30f52..9e1e704e0 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -203,7 +203,7 @@ class ControlNet(ControlBase): self.control_model = control_model self.load_device = load_device if control_model is not None: - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.compression_ratio = compression_ratio self.global_average_pooling = global_average_pooling diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 0949dee44..c0c51d51a 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,11 +1,12 @@ import math +import time from functools import partial from scipy import integrate import torch from torch import nn import torchsde -from tqdm.auto import trange, tqdm +from tqdm.auto import trange as trange_, tqdm from . import utils from . import deis @@ -13,6 +14,36 @@ from . import sa_solver import comfy.model_patcher import comfy.model_sampling +import comfy.memory_management + + +def trange(*args, **kwargs): + if comfy.memory_management.aimdo_allocator is None: + return trange_(*args, **kwargs) + + pbar = trange_(*args, **kwargs, smoothing=1.0) + pbar._i = 0 + pbar.set_postfix_str(" Model Initializing ... ") + + _update = pbar.update + + def warmup_update(n=1): + pbar._i += 1 + if pbar._i == 1: + pbar.i1_time = time.time() + pbar.set_postfix_str(" Model Initialization complete! ") + elif pbar._i == 2: + #bring forward the effective start time based the the diff between first and second iteration + #to attempt to remove load overhead from the final step rate estimate. + pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time) + pbar.set_postfix_str("") + + _update(n) + + pbar.update = warmup_update + return pbar + + def append_zero(x): return torch.cat([x, x.new_zeros([1])]) diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 51b6d1da8..1f68144e2 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -109,10 +109,10 @@ class HunyuanVideo15SRModel(): self.model_class = UPSAMPLERS.get(model_type) self.model = self.model_class(**config).eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=True) + return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/memory_management.py b/comfy/memory_management.py new file mode 100644 index 000000000..858bd4cc7 --- /dev/null +++ b/comfy/memory_management.py @@ -0,0 +1,81 @@ +import math +import torch +from typing import NamedTuple + +from comfy.quant_ops import QuantizedTensor + +class TensorGeometry(NamedTuple): + shape: any + dtype: torch.dtype + + def element_size(self): + info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype) + return info.bits // 8 + + def numel(self): + return math.prod(self.shape) + +def tensors_to_geometries(tensors, dtype=None): + geometries = [] + for t in tensors: + if t is None or isinstance(t, QuantizedTensor): + geometries.append(t) + continue + tdtype = t.dtype + if hasattr(t, "_model_dtype"): + tdtype = t._model_dtype + if dtype is not None: + tdtype = dtype + geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype)) + return geometries + +def vram_aligned_size(tensor): + if isinstance(tensor, list): + return sum([vram_aligned_size(t) for t in tensor]) + + if isinstance(tensor, QuantizedTensor): + inner_tensors, _ = tensor.__tensor_flatten__() + return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ]) + + if tensor is None: + return 0 + + size = tensor.numel() * tensor.element_size() + aligment_req = 1024 + return (size + aligment_req - 1) // aligment_req * aligment_req + +def interpret_gathered_like(tensors, gathered): + offset = 0 + dest_views = [] + + if gathered.dim() != 1 or gathered.element_size() != 1: + raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})") + + for tensor in tensors: + + if tensor is None: + dest_views.append(None) + continue + + if isinstance(tensor, QuantizedTensor): + inner_tensors, qt_ctx = tensor.__tensor_flatten__() + templates = { attr: getattr(tensor, attr) for attr in inner_tensors } + else: + templates = { "data": tensor } + + actuals = {} + for attr, template in templates.items(): + size = template.numel() * template.element_size() + if offset + size > gathered.numel(): + raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ") + actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape) + offset += vram_aligned_size(template) + + if isinstance(tensor, QuantizedTensor): + dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0)) + else: + dest_views.append(actuals["data"]) + + return dest_views + +aimdo_allocator = None diff --git a/comfy/model_base.py b/comfy/model_base.py index 66e52864d..85acdb66a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -149,6 +149,8 @@ class BaseModel(torch.nn.Module): self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) + comfy.model_management.archive_model_dtypes(self.diffusion_model) + self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 @@ -299,7 +301,7 @@ class BaseModel(torch.nn.Module): return out - def load_model_weights(self, sd, unet_prefix=""): + def load_model_weights(self, sd, unet_prefix="", assign=False): to_load = {} keys = list(sd.keys()) for k in keys: @@ -307,7 +309,7 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign) if len(m) > 0: logging.warning("unet missing: {}".format(m)) @@ -322,7 +324,7 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): extra_sds = [] if clip_state_dict is not None: extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) @@ -330,10 +332,7 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) if clip_vision_state_dict is not None: extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) - - unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) @@ -776,8 +775,8 @@ class StableAudio1(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} for k in d: s = d[k] diff --git a/comfy/model_management.py b/comfy/model_management.py index 9d39be7b2..758e718e8 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,13 @@ import platform import weakref import gc import os +from contextlib import nullcontext +import comfy.memory_management +import comfy.utils +import comfy.quant_ops + +import comfy_aimdo.torch +import comfy_aimdo.model_vbar class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -578,9 +585,15 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: + import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 + def get_free_ram(): + return comfy.windows.get_free_ram() +else: + def get_free_ram(): + return psutil.virtual_memory().available if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -592,7 +605,7 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() -def free_memory(memory_required, device, keep_loaded=[]): +def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0): cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -607,15 +620,23 @@ def free_memory(memory_required, device, keep_loaded=[]): for x in sorted(can_unload): i = x[-1] - memory_to_free = None + memory_to_free = 1e32 + ram_to_free = 1e32 if not DISABLE_SMART_MEMORY: - free_mem = get_free_memory(device) - if free_mem > memory_required: - break - memory_to_free = memory_required - free_mem - logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free): + memory_to_free = memory_required - get_free_memory(device) + ram_to_free = ram_required - get_free_ram() + + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. + memory_required -= current_loaded_models[i].model.loaded_size() + memory_to_free = 0 + if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): + logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) + if ram_to_free > 0: + logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") + current_loaded_models[i].model.partially_unload_ram(ram_to_free) for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -650,7 +671,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] + free_for_dynamic=True for x in models: + if not x.is_dynamic(): + free_for_dynamic = False loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) @@ -676,19 +700,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() + total_memory_required = {} + total_ram_required = {} for loaded_model in models_to_load: total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + #x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we + #want to do. + #FIXME: This should subtract off the to_load current pin consumption. + total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2 for device in total_memory_required: if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device) + free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device]) for device in total_memory_required: if device != torch.device("cpu"): free_mem = get_free_memory(device) if free_mem < minimum_memory_required: - models_l = free_memory(minimum_memory_required, device) + models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic) logging.info("{} models unloaded.".format(len(models_l))) for loaded_model in models_to_load: @@ -732,6 +762,9 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): do_gc = False + + reset_cast_buffers() + for i in range(len(current_loaded_models)): cur = current_loaded_models[i] if cur.is_dead(): @@ -749,6 +782,11 @@ def cleanup_models_gc(): logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) +def archive_model_dtypes(model): + for name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + def cleanup_models(): to_delete = [] @@ -792,7 +830,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev: + if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None: return torch_dev else: return cpu_dev @@ -1051,6 +1089,53 @@ def current_stream(device): return None stream_counters = {} + +STREAM_CAST_BUFFERS = {} +LARGEST_CASTED_WEIGHT = (None, 0) + +def get_cast_buffer(offload_stream, device, size, ref): + global LARGEST_CASTED_WEIGHT + + if offload_stream is not None: + wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) + else: + wf_context = nullcontext() + + cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None or cast_buffer.numel() < size: + if ref is LARGEST_CASTED_WEIGHT[0]: + #If there is one giant weight we do not want both streams to + #allocate a buffer for it. It's up to the caster to get the other + #offload stream in this corner case + return None + if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): + #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now + torch.cuda.synchronize() + del STREAM_CAST_BUFFERS[offload_stream] + del cast_buffer + #FIXME: This doesn't work in Aimdo because mempool cant clear cache + torch.cuda.empty_cache() + with wf_context: + cast_buffer = torch.empty((size), dtype=torch.int8, device=device) + STREAM_CAST_BUFFERS[offload_stream] = cast_buffer + + if size > LARGEST_CASTED_WEIGHT[1]: + LARGEST_CASTED_WEIGHT = (ref, size) + + return cast_buffer + +def reset_cast_buffers(): + global LARGEST_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) + for offload_stream in STREAM_CAST_BUFFERS: + offload_stream.synchronize() + STREAM_CAST_BUFFERS.clear() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.empty_cache() + def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) if NUM_STREAMS == 0: @@ -1093,7 +1178,53 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): + +def cast_to_gathered(tensors, r, non_blocking=False, stream=None): + wf_context = nullcontext() + if stream is not None: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + + dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + with wf_context: + for tensor in tensors: + dest_view = dest_views.pop(0) + if tensor is None: + continue + dest_view.copy_(tensor, non_blocking=non_blocking) + + +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): + if hasattr(weight, "_v"): + #Unexpected usage patterns. There is no reason these don't work but they + #have no testing and no callers do this. + assert r is None + assert stream is None + + r = torch.empty_like(weight, dtype=weight._model_dtype, device=device) + + signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) + if signature is not None: + raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) + v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0] + + if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + #always take a deep copy even if _v is good, as we have no reasonable point to unpin + #a non comfy weight + r.copy_(v_tensor) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + return r + + r.copy_(weight, non_blocking=non_blocking) + + if signature is not None: + weight._v_signature = signature + v_tensor.copy_(r) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + + return r + if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: @@ -1112,10 +1243,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if hasattr(wf_context, "as_context"): wf_context = wf_context.as_context(stream) with wf_context: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r @@ -1135,7 +1268,7 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) -PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) def discard_cuda_async_error(): try: @@ -1557,8 +1690,11 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, get_torch_device()) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f6b80a40f..57b53d8c5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -38,19 +38,7 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP - -def string_to_seed(data): - crc = 0xFFFFFFFF - for byte in data: - if isinstance(byte, str): - byte = ord(byte) - crc ^= byte - for _ in range(8): - if crc & 1: - crc = (crc >> 1) ^ 0xEDB88320 - else: - crc >>= 1 - return crc ^ 0xFFFFFFFF +import comfy_aimdo.model_vbar def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -212,6 +200,27 @@ class MemoryCounter: def decrement(self, used: int): self.value -= used +CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0) + +class LazyCastingParam(torch.nn.Parameter): + def __new__(cls, model, key, tensor): + return super().__new__(cls, tensor) + + def __init__(self, model, key, tensor): + self.model = model + self.key = key + + @property + def device(self): + return CustomTorchDevice + + #safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is + #then just a short lived thing in the safetensors serialization logic inside its big for loop over + #all weights getting garbage collected per-weight + def to(self, *args, **kwargs): + return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu") + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -269,6 +278,9 @@ class ModelPatcher: if not hasattr(self.model, 'model_offload_buffer_memory'): self.model.model_offload_buffer_memory = 0 + def is_dynamic(self): + return False + def model_size(self): if self.size > 0: return self.size @@ -284,6 +296,9 @@ class ModelPatcher: def lowvram_patch_counter(self): return self.model.lowvram_patch_counter + def get_free_memory(self, device): + return comfy.model_management.get_free_memory(device) + def clone(self): n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} @@ -611,14 +626,14 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): - if key not in self.patches: - return - + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): weight, set_func, convert_func = get_key_weight(self.model, key) + if key not in self.patches: + return weight + inplace_update = self.weight_inplace_update or inplace_update - if key not in self.backup: + if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) @@ -631,13 +646,15 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) - if inplace_update: + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) + if return_weight: + return out_weight + elif inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -654,7 +671,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self): + def _load_list(self, prio_comfy_cast_weights=False): loading = [] for n, m in self.model.named_modules(): params = [] @@ -681,7 +698,8 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - loading.append((module_offload_mem, module_mem, n, m, params)) + prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () + loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -984,6 +1002,9 @@ class ModelPatcher: return self.model.model_loaded_weight_memory - current_used + def partially_unload_ram(self, ram_to_unload): + pass + def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) @@ -1317,10 +1338,10 @@ class ModelPatcher: key, original_weights=original_weights) del original_weights[key] if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) comfy.utils.copy_to_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=True, seed=string_to_seed(key)) + set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key)) if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # TODO: disable caching if not enough system RAM to do so target_device = self.offload_device @@ -1355,7 +1376,249 @@ class ModelPatcher: self.unpatch_hooks() self.clear_cached_hook_weights() + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model.diffusion_model.state_dict() + for k, v in unet_state_dict.items(): + op_keys = k.rsplit('.', 1) + if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]: + continue + try: + op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + except: + continue + if not op or not hasattr(op, "comfy_cast_weights") or \ + (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): + continue + key = "diffusion_model." + k + unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) + return self.model.state_dict_for_saving(unet_state_dict) + def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) +class ModelPatcherDynamic(ModelPatcher): + + def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False): + if load_device is not None and comfy.model_management.is_device_cpu(load_device): + #reroute to default MP for CPUs + return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update) + return super().__new__(cls) + + def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): + super().__init__(model, load_device, offload_device, size, weight_inplace_update) + #this is now way more dynamic and we dont support the same base model for both Dynamic + #and non-dynamic patchers. + if hasattr(self.model, "model_loaded_weight_memory"): + del self.model.model_loaded_weight_memory + if not hasattr(self.model, "dynamic_vbars"): + self.model.dynamic_vbars = {} + assert load_device is not None + + def is_dynamic(self): + return True + + def _vbar_get(self, create=False): + if self.load_device == torch.device("cpu"): + return None + vbar = self.model.dynamic_vbars.get(self.load_device, None) + if create and vbar is None: + # x10. We dont know what model defined type casts we have in the vbar, but virtual address + # space is pretty free. This will cover someone casting an entire model from FP4 to FP32 + # with some left over. + vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index) + self.model.dynamic_vbars[self.load_device] = vbar + return vbar + + def loaded_size(self): + vbar = self._vbar_get() + if vbar is None: + return 0 + return vbar.loaded_size() + + def get_free_memory(self, device): + #NOTE: on high condition / batch counts, estimate should have already vacated + #all non-dynamic models so this is safe even if its not 100% true that this + #would all be avaiable for inference use. + return comfy.model_management.get_total_memory(device) - self.model_size() + + #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. + + def pin_weight_to_device(self, key): + raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading") + + def unpin_weight(self, key): + raise RuntimeError("unpin_weight invalid for dymamic weight loading") + + def unpin_all_weights(self): + self.partially_unload_ram(1e32) + + def memory_required(self, input_shape): + #Pad this significantly. We are trying to get away from precise estimates. This + #estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you + #use all ModelPatcherDynamic this is ignored and its all done dynamically. + return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3) + + + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False): + + #Force patching doesn't make sense in Dynamic loading, as you dont know what does and + #doesn't need to be forced at this stage. The only thing you could do would be patch + #it all on CPU which consumes huge RAM. + assert not force_patch_weights + + #Full load doesn't make sense as we dont actually have any loader capability here and + #now. + assert not full_load + + assert device_to == self.load_device + + num_patches = 0 + allocated_size = 0 + + with self.use_ejected(): + self.unpatch_hooks() + + vbar = self._vbar_get(create=True) + if vbar is not None: + vbar.prioritize() + + #We have way more tools for acceleration on comfy weight offloading, so always + #prioritize the non-comfy weights (note the order reverse). + loading = self._load_list(prio_comfy_cast_weights=True) + loading.sort(reverse=True) + + for x in loading: + _, _, _, n, m, params = x + + def set_dirty(item, dirty): + if dirty or not hasattr(item, "_v_signature"): + item._v_signature = None + + def setup_param(self, m, n, param_key): + nonlocal num_patches + key = "{}.{}".format(n, param_key) + + weight_function = [] + + weight, _, _ = get_key_weight(self.model, key) + if weight is None: + return 0 + if key in self.patches: + setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + num_patches += 1 + else: + setattr(m, param_key + "_lowvram_function", None) + + if key in self.weight_wrapper_patches: + weight_function.extend(self.weight_wrapper_patches[key]) + setattr(m, param_key + "_function", weight_function) + geometry = weight + if not isinstance(weight, QuantizedTensor): + model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype) + weight._model_dtype = model_dtype + geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) + return comfy.memory_management.vram_aligned_size(geometry) + + if hasattr(m, "comfy_cast_weights"): + m.comfy_cast_weights = True + m.pin_failed = False + m.seed_key = n + set_dirty(m, dirty) + + v_weight_size = 0 + v_weight_size += setup_param(self, m, n, "weight") + v_weight_size += setup_param(self, m, n, "bias") + + if vbar is not None and not hasattr(m, "_v"): + m._v = vbar.alloc(v_weight_size) + allocated_size += v_weight_size + + else: + for param in params: + key = "{}.{}".format(n, param) + weight, _, _ = get_key_weight(self.model, key) + weight.seed_key = key + set_dirty(weight, dirty) + geometry = weight + model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype) + geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) + weight_size = geometry.numel() * geometry.element_size() + if vbar is not None and not hasattr(weight, "_v"): + weight._v = vbar.alloc(weight_size) + weight._model_dtype = model_dtype + allocated_size += weight_size + + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + + self.model.device = device_to + self.model.current_weight_patches_uuid = self.patches_uuid + + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): + #These are all super dangerous. Who knows what the custom nodes actually do here... + callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) + + self.apply_hooks(self.forced_hooks, force_apply=True) + + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): + assert not force_patch_weights #See above + assert self.load_device != torch.device("cpu") + + vbar = self._vbar_get() + return 0 if vbar is None else vbar.free_memory(memory_to_free) + + def partially_unload_ram(self, ram_to_unload): + loading = self._load_list(prio_comfy_cast_weights=True) + for x in loading: + _, _, _, _, m, _ = x + ram_to_unload -= comfy.pinned_memory.unpin_memory(m) + if ram_to_unload <= 0: + return + + def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): + #This isn't used by the core at all and can only be to load a model out of + #the control of proper model_managment. If you are a custom node author reading + #this, the correct pattern is to call load_models_gpu() to get a proper + #managed load of your model. + assert not load_weights + return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights) + + def unpatch_model(self, device_to=None, unpatch_weights=True): + super().unpatch_model(device_to=None, unpatch_weights=False) + + if unpatch_weights: + self.partially_unload_ram(1e32) + self.partially_unload(None) + + def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): + assert not force_patch_weights #See above + with self.use_ejected(skip_and_inject_on_exit_only=True): + dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid) + + self.unpatch_model(self.offload_device, unpatch_weights=False) + self.patch_model(load_weights=False) + + try: + self.load(device_to, dirty=dirty) + except Exception as e: + self.detach() + raise e + #ModelPatcher::partially_load returns a number on what got loaded but + #nothing in core uses this and we have no data in the Dynamic world. Hit + #the custom node devs with a None rather than a 0 that would mislead any + #logic they might have. + return None + + def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): + assert False #Should be unreachable - we dont ever cache in the new implementation + + def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): + if key not in combined_patches: + return + + raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup") + + def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: + pass + +CoreModelPatcher = ModelPatcher diff --git a/comfy/ops.py b/comfy/ops.py index e406ba7ed..c3a1825ce 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,10 +19,16 @@ import torch import logging import comfy.model_management -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import comfy.float import comfy.rmsnorm import json +import comfy.memory_management +import comfy.pinned_memory +import comfy.utils + +import comfy_aimdo.model_vbar +import comfy_aimdo.torch def run_every_op(): if torch.compiler.is_compiling(): @@ -72,7 +78,115 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): + offload_stream = None + xfer_dest = None + cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) + + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + if signature is not None: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + + if not resident: + cast_dest = None + + xfer_source = [ s.weight, s.bias ] + + pin = comfy.pinned_memory.get_pin(s) + if pin is not None: + xfer_source = [ pin ] + else: + for data, geometry in zip([ s.weight, s.bias ], cast_geometry): + if data is None: + continue + if data.dtype != geometry.dtype: + cast_dest = xfer_dest + if cast_dest is None: + cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) + xfer_dest = None + break + + dest_size = comfy.memory_management.vram_aligned_size(xfer_source) + offload_stream = comfy.model_management.get_offload_stream(device) + if xfer_dest is None and offload_stream is not None: + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + offload_stream = comfy.model_management.get_offload_stream(device) + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) + offload_stream = None + + if signature is None and pin is None: + comfy.pinned_memory.pin_memory(s) + pin = comfy.pinned_memory.get_pin(s) + else: + pin = None + + if pin is not None: + comfy.model_management.cast_to_gathered(xfer_source, pin) + xfer_source = [ pin ] + #send it over + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + comfy.model_management.sync_stream(device, offload_stream) + + if cast_dest is not None: + for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest), + comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + if post_cast is not None: + post_cast.copy_(pre_cast) + xfer_dest = cast_dest + + params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + weight = params[0] + bias = params[1] + + def post_cast(s, param_key, x, dtype, resident, update_weight): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + fns = getattr(s, param_key + "_function", []) + + orig = x + + def to_dequant(tensor, dtype): + tensor = tensor.to(dtype=dtype) + if isinstance(tensor, QuantizedTensor): + tensor = tensor.dequantize() + return tensor + + if orig.dtype != dtype or len(fns) > 0: + x = to_dequant(x, dtype) + if not resident and lowvram_fn is not None: + x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) + #FIXME: this is not accurate, we need to be sensitive to the compute dtype + x = lowvram_fn(x) + if (isinstance(orig, QuantizedTensor) and + (orig.dtype == dtype and len(fns) == 0 or update_weight)): + seed = comfy.utils.string_to_seed(s.seed_key) + y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + if orig.dtype == dtype and len(fns) == 0: + #The layer actually wants our freshly saved QT + x = y + else: + y = x + if update_weight: + orig.copy_(y) + for f in fns: + x = f(x) + return x + + update_weight = signature is not None + + weight = post_cast(s, "weight", weight, dtype, resident, update_weight) + if s.bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) + s._v_signature=signature + + #FIXME: weird offload return protocol + return weight, bias, (offload_stream, device if signature is not None else None, None) + + +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None): # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This # will add async-offload support to your cast and improve performance. @@ -87,22 +201,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + non_blocking = comfy.model_management.device_supports_non_blocking(device) + + if hasattr(s, "_v"): + return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): offload_stream = comfy.model_management.get_offload_stream(device) else: offload_stream = None - non_blocking = comfy.model_management.device_supports_non_blocking(device) + bias = None + weight = None + + if offload_stream is not None and not args.cuda_malloc: + cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ]) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + #The streams can be uneven in buffer capability and reject us. Retry to get the other stream + if cast_buffer is None: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer) + weight = params[0] + bias = params[1] weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 - weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight) - bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias) comfy.model_management.sync_stream(device, offload_stream) @@ -110,6 +240,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight_a = weight if s.bias is not None: + bias = bias.to(dtype=bias_dtype) for f in s.bias_function: bias = f(bias) @@ -131,14 +262,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return os, weight_a, bias_a = offload_stream + device=None + #FIXME: This is not good RTTI + if not isinstance(weight_a, torch.Tensor): + comfy_aimdo.model_vbar.vbar_unpin(s._v) + device = weight_a if os is None: return - if weight_a is not None: - device = weight_a.device - else: - if bias_a is None: - return - device = bias_a.device + if device is None: + if weight_a is not None: + device = weight_a.device + else: + if bias_a is None: + return + device = bias_a.device os.wait_stream(comfy.model_management.current_stream(device)) @@ -149,6 +286,57 @@ class CastWeightBiasOp: class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): + + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + super().__init__(in_features, out_features, bias, device, dtype) + return + + # Issue is with `torch.empty` still reserving the full memory for the layer. + # Windows doesn't over-commit memory so without this, We are momentarily commit + # charged for the weight even though we might zero-copy it when we load the + # state dict. If the commit charge exceeds the ceiling we can destabilize the + # system. + torch.nn.Module.__init__(self) + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.bias = None + self.comfy_need_lazy_init_bias=bias + self.weight_comfy_model_dtype = dtype + self.bias_comfy_model_dtype = dtype + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + prefix_len = len(prefix) + for k,v in state_dict.items(): + if k[prefix_len:] == "weight": + if not assign_to_params_buffers: + v = v.clone() + self.weight = torch.nn.Parameter(v, requires_grad=False) + elif k[prefix_len:] == "bias" and v is not None: + if not assign_to_params_buffers: + v = v.clone() + self.bias = torch.nn.Parameter(v, requires_grad=False) + else: + unexpected_keys.append(k) + + #Reconcile default construction of the weight if its missing. + if self.weight is None: + v = torch.zeros(self.in_features, self.out_features) + self.weight = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"weight") + if self.bias is None and self.comfy_need_lazy_init_bias: + v = torch.zeros(self.out_features,) + self.bias = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"bias") + + def reset_parameters(self): return None @@ -655,8 +843,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + def forward_comfy_cast_weights(self, input, compute_dtype=None): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -666,6 +854,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec input_shape = input.shape reshaped_3d = False + #If cast needs to apply lora, it should be done in the compute dtype + compute_dtype = input.dtype if (getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and @@ -684,7 +874,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec scale = comfy.model_management.cast_to_device(scale, input.device, None) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - output = self.forward_comfy_cast_weights(input) + + output = self.forward_comfy_cast_weights(input, compute_dtype) # Reshape output back to 3D if input was 3D if reshaped_3d: diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py new file mode 100644 index 000000000..0650e4d1a --- /dev/null +++ b/comfy/pinned_memory.py @@ -0,0 +1,30 @@ +import torch +import comfy.model_management +import comfy.memory_management + +from comfy.cli_args import args + +def get_pin(module): + return getattr(module, "_pin", None) + +def pin_memory(module): + if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + return + #FIXME: This is a RAM cache trigger event + params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ]) + size = comfy.memory_management.vram_aligned_size(params) + pin = torch.empty((size,), dtype=torch.uint8) + if comfy.model_management.pin_memory(pin): + module._pin = pin + else: + module.pin_failed = True + return False + return True + +def unpin_memory(module): + if get_pin(module) is None: + return 0 + size = module._pin.numel() * module._pin.element_size() + comfy.model_management.unpin_memory(module._pin) + del module._pin + return size diff --git a/comfy/samplers.py b/comfy/samplers.py index 1989ef107..8b9782956 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: import torch from functools import partial import collections -from comfy import model_management import math import logging import comfy.sampler_helpers @@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens to_batch_temp.reverse() to_batch = to_batch_temp[:1] - free_memory = model_management.get_free_memory(x_in.device) + free_memory = model.current_patcher.get_free_memory(x_in.device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] diff --git a/comfy/sd.py b/comfy/sd.py index f627f7d55..fd0ac85e7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -228,8 +228,10 @@ class CLIP: self.cond_stage_model.to(offload_device) logging.warning("Had to shift TE back.") + model_management.archive_model_dtypes(self.cond_stage_model) + self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -389,8 +391,18 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: - return self.cond_stage_model.load_state_dict(sd, strict=False) + return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) else: + can_assign = self.patcher.is_dynamic() + self.cond_stage_model.can_assign_sd = can_assign + + # The CLIP models are a pretty complex web of wrappers and its + # a bit of an API change to plumb this all the way through. + # So spray paint the model with this flag that the loading + # nn.Module can then inspect for itself. + for m in self.cond_stage_model.modules(): + m.can_assign_sd = can_assign + return self.cond_stage_model.load_sd(sd) def get_sd(self): @@ -765,12 +777,7 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = self.first_stage_model.load_state_dict(sd, strict=False) - if len(m) > 0: - logging.warning("Missing VAE keys {}".format(m)) - - if len(u) > 0: - logging.debug("Leftover VAE keys {}".format(u)) + model_management.archive_model_dtypes(self.first_stage_model) if device is None: device = model_management.vae_device() @@ -782,7 +789,18 @@ class VAE: self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device() - self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + mp = comfy.model_patcher.CoreModelPatcher + if self.disable_offload: + mp = comfy.model_patcher.ModelPatcher + self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device) + + m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) + if len(m) > 0: + logging.warning("Missing VAE keys {}".format(m)) + + if len(u) > 0: + logging.debug("Leftover VAE keys {}".format(u)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) self.model_size() @@ -897,7 +915,7 @@ class VAE: try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -971,7 +989,7 @@ class VAE: try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) samples = None @@ -1432,7 +1450,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) + return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def model_detection_error_hint(path, state_dict): filename = os.path.basename(path) @@ -1520,7 +1538,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model.load_model_weights(sd, diffusion_model_prefix) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) @@ -1563,7 +1582,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded diffusion model directly to GPU") model_management.load_models_gpu([model_patcher], force_full_load=True) @@ -1655,13 +1673,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model = model.to(offload_device) - model.load_model_weights(new_sd, "") + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) + if not model_management.is_device_cpu(offload_device): + model.to(offload_device) + model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) left_over = sd.keys() if len(left_over) > 0: logging.info("left over keys in diffusion model: {}".format(left_over)) - return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) - + return model_patcher def load_diffusion_model(unet_path, model_options={}): sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) @@ -1692,9 +1711,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if metadata is None: metadata = {} - model_management.load_models_gpu(load_models, force_patch_weights=True) + model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) + sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) for k in extra_keys: sd[k] = extra_keys[k] diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d4f22120b..9ecfc9c55 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): - return self.transformer.load_state_dict(sd, strict=False) + return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) def parse_parentheses(string): result = [] diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e49161964..26573fb12 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -125,7 +125,7 @@ class LTXAVTEModel(torch.nn.Module): for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} if component_sd: - missing, unexpected = component.load_state_dict(component_sd, strict=False) + missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) missing_all.extend([f"{prefix}{k}" for k in missing]) unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) diff --git a/comfy/utils.py b/comfy/utils.py index d97d753e6..9e98eb176 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,9 +28,10 @@ import logging import itertools from torch.nn.functional import interpolate from einops import rearrange -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram import json import time +import mmap MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -56,21 +57,67 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in else: logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") +# Current as of safetensors 0.7.0 +_TYPES = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + "C64": torch.complex64, + + "U64": torch.uint64, + "U32": torch.uint32, + "U16": torch.uint16, +} + +def load_safetensors(ckpt): + f = open(ckpt, "rb") + mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + header_size = struct.unpack(" 0: message = e.args[0] @@ -1308,3 +1355,16 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8) return state_dict, metadata + +def string_to_seed(data): + crc = 0xFFFFFFFF + for byte in data: + if isinstance(byte, str): + byte = ord(byte) + crc ^= byte + for _ in range(8): + if crc & 1: + crc = (crc >> 1) ^ 0xEDB88320 + else: + crc >>= 1 + return crc ^ 0xFFFFFFFF diff --git a/comfy/windows.py b/comfy/windows.py new file mode 100644 index 000000000..213dc481d --- /dev/null +++ b/comfy/windows.py @@ -0,0 +1,52 @@ +import ctypes +import logging +import psutil +from ctypes import wintypes + +import comfy_aimdo.control + +psapi = ctypes.WinDLL("psapi") +kernel32 = ctypes.WinDLL("kernel32") + +class PERFORMANCE_INFORMATION(ctypes.Structure): + _fields_ = [ + ("cb", wintypes.DWORD), + ("CommitTotal", ctypes.c_size_t), + ("CommitLimit", ctypes.c_size_t), + ("CommitPeak", ctypes.c_size_t), + ("PhysicalTotal", ctypes.c_size_t), + ("PhysicalAvailable", ctypes.c_size_t), + ("SystemCache", ctypes.c_size_t), + ("KernelTotal", ctypes.c_size_t), + ("KernelPaged", ctypes.c_size_t), + ("KernelNonpaged", ctypes.c_size_t), + ("PageSize", ctypes.c_size_t), + ("HandleCount", wintypes.DWORD), + ("ProcessCount", wintypes.DWORD), + ("ThreadCount", wintypes.DWORD), + ] + +def get_free_ram(): + #Windows is way too conservative and chalks recently used uncommitted model RAM + #as "in-use". So, calculate free RAM for the sake of general use as the greater of: + # + #1: What psutil says + #2: Total Memory - (Committed Memory - VRAM in use) + # + #We have to subtract VRAM in use from the comitted memory as WDDM creates a naked + #commit charge for all VRAM used just incase it wants to page it all out. This just + #isn't realistic so "overcommit" on our calculations by just subtracting it off. + + pi = PERFORMANCE_INFORMATION() + pi.cb = ctypes.sizeof(pi) + + if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb): + logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal") + return psutil.virtual_memory().available + + committed = pi.CommitTotal * pi.PageSize + total = pi.PhysicalTotal * pi.PageSize + + return max(psutil.virtual_memory().available, + total - (committed - comfy_aimdo.control.get_total_vram_usage())) + diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 82c4754a3..176e6bc2f 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -267,9 +267,9 @@ class ModelPatchLoader: device=comfy.model_management.unet_offload_device(), operations=comfy.ops.manual_cast) - model.load_state_dict(sd) - model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) - return (model,) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + model.load_state_dict(sd, assign=model_patcher.is_dynamic()) + return (model_patcher,) class DiffSynthCnetPatch: diff --git a/cuda_malloc.py b/cuda_malloc.py index ee2bc4b69..b2182df37 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,8 +1,10 @@ import os import importlib.util -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import subprocess +import comfy_aimdo.control + #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. def get_gpu_names(): if os.name == 'nt': @@ -85,8 +87,14 @@ if not args.cuda_malloc: except: pass +if enables_dynamic_vram() and comfy_aimdo.control.init(): + args.cuda_malloc = False + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "" -if args.cuda_malloc and not args.disable_cuda_malloc: +if args.disable_cuda_malloc: + args.cuda_malloc = False + +if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" diff --git a/execution.py b/execution.py index 4b4f63c80..93fafc4a2 100644 --- a/execution.py +++ b/execution.py @@ -9,9 +9,11 @@ import traceback from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union import asyncio +from contextlib import nullcontext import torch +import comfy.memory_management import comfy.model_management from latent_preview import set_preview_method import nodes @@ -515,7 +517,19 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + + #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows + #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc + #that we just want to cull out each model run. + allocator = comfy.memory_management.aimdo_allocator + with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): + try: + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + finally: + if allocator is not None: + comfy.model_management.reset_cast_buffers() + torch.cuda.synchronize() + if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) diff --git a/main.py b/main.py index 37b06c1fa..b8c951375 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ import os import importlib.util import folder_paths import time -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger from app.assets.scanner import seed_assets import itertools @@ -173,6 +173,7 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") + import comfy.utils import execution @@ -184,6 +185,33 @@ import comfyui_version import app.logger import hook_breaker_ac10a0 +import comfy.memory_management +import comfy.model_patcher + +import comfy_aimdo.control +import comfy_aimdo.torch + +if enables_dynamic_vram(): + if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + if args.verbose == 'DEBUG': + comfy_aimdo.control.set_log_debug() + elif args.verbose == 'CRITICAL': + comfy_aimdo.control.set_log_critical() + elif args.verbose == 'ERROR': + comfy_aimdo.control.set_log_error() + elif args.verbose == 'WARNING': + comfy_aimdo.control.set_log_warning() + else: #INFO + comfy_aimdo.control.set_log_info() + + comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic + comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() + logging.info("DynamicVRAM support detected and enabled") + else: + logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + comfy.memory_management.aimdo_allocator = None + + def cuda_malloc_warning(): device = comfy.model_management.get_torch_device() device_name = comfy.model_management.get_torch_device_name(device) diff --git a/requirements.txt b/requirements.txt index 4ac94cb16..be0bc537e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 +comfy-aimdo>=0.1.6 requests #non essential dependencies: From 32621c6a11f6ce1c0dd46aeb08e35dd64ec804f0 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 31 Jan 2026 22:13:48 -0800 Subject: [PATCH 04/32] fix: improve error message when node type is missing (#12194) - Change error type from 'invalid_prompt' to 'missing_node_type' for frontend detection - Add extra_info with node_id, class_type, and node_title (from _meta.title) - Improve user-facing message: 'Node X not found. The custom node may not be installed.' --- execution.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/execution.py b/execution.py index 93fafc4a2..3dbab82e6 100644 --- a/execution.py +++ b/execution.py @@ -1014,22 +1014,34 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ outputs = set() for x in prompt: if 'class_type' not in prompt[x]: + node_data = prompt[x] + node_title = node_data.get('_meta', {}).get('title') error = { - "type": "invalid_prompt", - "message": "Cannot execute because a node is missing the class_type property.", + "type": "missing_node_type", + "message": f"Node '{node_title or f'ID #{x}'}' has no class_type. The workflow may be corrupted or a custom node is missing.", "details": f"Node ID '#{x}'", - "extra_info": {} + "extra_info": { + "node_id": x, + "class_type": None, + "node_title": node_title + } } return (False, error, [], {}) class_type = prompt[x]['class_type'] class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) if class_ is None: + node_data = prompt[x] + node_title = node_data.get('_meta', {}).get('title', class_type) error = { - "type": "invalid_prompt", - "message": f"Cannot execute because node {class_type} does not exist.", + "type": "missing_node_type", + "message": f"Node '{node_title}' not found. The custom node may not be installed.", "details": f"Node ID '#{x}'", - "extra_info": {} + "extra_info": { + "node_id": x, + "class_type": class_type, + "node_title": node_title + } } return (False, error, [], {}) From 667a1b8878187f7a5c8955b272e33c16e4c46bb7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 31 Jan 2026 22:55:18 -0800 Subject: [PATCH 05/32] Fix some custom nodes breaking. (#12203) --- comfy/model_patcher.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 57b53d8c5..c8e6f088f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -111,6 +111,10 @@ def move_weight_functions(m, device): memory += f.move_to(device=device) return memory +def string_to_seed(data): + logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils") + return comfy.utils.string_to_seed(data) + class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key From 361b9a82a3e2445fc22e352df187138b0c8f67fb Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Feb 2026 08:42:32 -0800 Subject: [PATCH 06/32] fix pinning with model defined dtype (#12208) pinned memory was converted back to pinning the CPU side weight without any changes. Fix the pinner to use the CPU weight and not the model defined geometry. This will either save RAM or stop buffer overruns when the types mismatch. Fix the model defined weight caster to use the [ s.weight, s.bias ] interpretation, as xfer_dest might be the flattened pin now. Fix the detection of needing to cast to not be conditional on !pin. --- comfy/ops.py | 22 +++++++++++----------- comfy/pinned_memory.py | 3 +-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index c3a1825ce..53c5e4dc3 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -96,16 +96,16 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu pin = comfy.pinned_memory.get_pin(s) if pin is not None: xfer_source = [ pin ] - else: - for data, geometry in zip([ s.weight, s.bias ], cast_geometry): - if data is None: - continue - if data.dtype != geometry.dtype: - cast_dest = xfer_dest - if cast_dest is None: - cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) - xfer_dest = None - break + + for data, geometry in zip([ s.weight, s.bias ], cast_geometry): + if data is None: + continue + if data.dtype != geometry.dtype: + cast_dest = xfer_dest + if cast_dest is None: + cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) + xfer_dest = None + break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) offload_stream = comfy.model_management.get_offload_stream(device) @@ -132,7 +132,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu comfy.model_management.sync_stream(device, offload_stream) if cast_dest is not None: - for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest), + for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): if post_cast is not None: post_cast.copy_(pre_cast) diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 0650e4d1a..8acc327a7 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -11,8 +11,7 @@ def pin_memory(module): if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: return #FIXME: This is a RAM cache trigger event - params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ]) - size = comfy.memory_management.vram_aligned_size(params) + size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) pin = torch.empty((size,), dtype=torch.uint8) if comfy.model_management.pin_memory(pin): module._pin = pin From 794d05bdb12be95f0e58dbb1728542ac4c7998aa Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:09:21 -0800 Subject: [PATCH 07/32] dynamic_vram: respect argument cast dtypes in non-comfy weights (#12209) This function has a dtype argument that allows the caller to set the dtype in the cast. TIL Some models override this on weight casts, which means its the highest priority. Priority scheme is: argument > model dtype > state dict dtype --- comfy/model_management.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 758e718e8..6b1166b94 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1202,27 +1202,36 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str assert r is None assert stream is None - r = torch.empty_like(weight, dtype=weight._model_dtype, device=device) + cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ]) + + if dtype is None: + dtype = weight._model_dtype + + r = torch.empty_like(weight, dtype=dtype, device=device) signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) if signature is not None: raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0] - - if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] + if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + weight._v_signature = signature + #Send it over + v_tensor.copy_(weight, non_blocking=non_blocking) #always take a deep copy even if _v is good, as we have no reasonable point to unpin #a non comfy weight r.copy_(v_tensor) comfy_aimdo.model_vbar.vbar_unpin(weight._v) return r + if weight.dtype != r.dtype and weight.dtype != weight._model_dtype: + #Offloaded casting could skip this, however it would make the quantizations + #inconsistent between loaded and offloaded weights. So force the double casting + #that would happen in regular flow to make offload deterministic. + cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device) + cast_buffer.copy_(weight, non_blocking=non_blocking) + weight = cast_buffer r.copy_(weight, non_blocking=non_blocking) - if signature is not None: - weight._v_signature = signature - v_tensor.copy_(r) - comfy_aimdo.model_vbar.vbar_unpin(weight._v) - return r if device is None or weight.device == device: From 2b5da3b72ebb8017661b0d58b548d4aeb53b7e42 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:09:55 -0800 Subject: [PATCH 08/32] dynamic_vram: silence pytorch buffer warning (#12210) This is log clutter and concerning to users. Its a false alarm. --- comfy/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index 9e98eb176..c1b536833 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -32,6 +32,7 @@ from comfy.cli_args import args, enables_dynamic_vram import json import time import mmap +import warnings MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -85,7 +86,10 @@ def load_safetensors(ckpt): header_size = struct.unpack(" Date: Sun, 1 Feb 2026 17:10:15 -0800 Subject: [PATCH 09/32] requirements: bump comfy-aimdo to 0.1.7 (#12211) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index be0bc537e..3ca417dd8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.6 +comfy-aimdo>=0.1.7 requests #non essential dependencies: From 021ba2071985f9e3b2984f396a238095c4a64832 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 1 Feb 2026 17:12:52 -0800 Subject: [PATCH 10/32] Fix issue with parameters on root model object. (#12216) --- comfy/model_patcher.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c8e6f088f..b70c031bf 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -161,6 +161,11 @@ def get_key_weight(model, key): return weight, set_func, convert_func +def key_param_name_to_key(key, param): + if len(key) == 0: + return param + return "{}.{}".format(key, param) + class AutoPatcherEjector: def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False): self.model = model @@ -795,7 +800,7 @@ class ModelPatcher: continue for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) if comfy.model_management.is_device_cuda(device_to): @@ -811,7 +816,7 @@ class ModelPatcher: n = x[1] params = x[3] for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + self.pin_weight_to_device(key_param_name_to_key(n, param)) usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else "" if lowvram_counter > 0: @@ -917,7 +922,7 @@ class ModelPatcher: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: move_weight = True for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) bk = self.backup.get(key, None) if bk is not None: if not lowvram_possible: @@ -968,7 +973,7 @@ class ModelPatcher: logging.debug("freed {}".format(n)) for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + self.pin_weight_to_device(key_param_name_to_key(n, param)) self.model.model_lowvram = True @@ -1501,7 +1506,7 @@ class ModelPatcherDynamic(ModelPatcher): def setup_param(self, m, n, param_key): nonlocal num_patches - key = "{}.{}".format(n, param_key) + key = key_param_name_to_key(n, param_key) weight_function = [] @@ -1540,7 +1545,7 @@ class ModelPatcherDynamic(ModelPatcher): else: for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) weight, _, _ = get_key_weight(self.model, key) weight.seed_key = key set_dirty(weight, dirty) From dd86b155210df9b34f479d70dad675aa782a30ef Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 2 Feb 2026 00:51:09 -0800 Subject: [PATCH 11/32] Enable embeddings for some qwen 3 models. (#12218) --- comfy/text_encoders/anima.py | 2 +- comfy/text_encoders/flux.py | 6 +++--- comfy/text_encoders/z_image.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py index 41f95bcb6..b6f58cb25 100644 --- a/comfy/text_encoders/anima.py +++ b/comfy/text_encoders/anima.py @@ -8,7 +8,7 @@ import torch class Qwen3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index f67a5f805..1ae398789 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -118,7 +118,7 @@ class MistralTokenizerClass: class Mistral3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): self.tekken_data = tokenizer_data.get("tekken_model", None) - super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) def state_dict(self): return {"tekken_model": self.tekken_data} @@ -176,12 +176,12 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): class Qwen3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) class Qwen3Tokenizer8B(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) class KleinTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"): diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index ad41bfb1e..33b7cf594 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -6,7 +6,7 @@ import os class Qwen3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) class ZImageTokenizer(sd1_clip.SD1Tokenizer): From 37f711d4a1d429e6b390b01729510525155385e1 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 2 Feb 2026 14:34:46 -0800 Subject: [PATCH 12/32] mm: Fix cast buffers with intel offloading (#12229) Intel has offloading support but there were some nvidia calls in the new cast buffer stuff. --- comfy/model_management.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 6b1166b94..2167f81bf 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1112,11 +1112,11 @@ def get_cast_buffer(offload_stream, device, size, ref): return None if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now - torch.cuda.synchronize() + synchronize() del STREAM_CAST_BUFFERS[offload_stream] del cast_buffer #FIXME: This doesn't work in Aimdo because mempool cant clear cache - torch.cuda.empty_cache() + soft_empty_cache() with wf_context: cast_buffer = torch.empty((size), dtype=torch.int8, device=device) STREAM_CAST_BUFFERS[offload_stream] = cast_buffer @@ -1132,9 +1132,7 @@ def reset_cast_buffers(): for offload_stream in STREAM_CAST_BUFFERS: offload_stream.synchronize() STREAM_CAST_BUFFERS.clear() - if comfy.memory_management.aimdo_allocator is None: - #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist - torch.cuda.empty_cache() + soft_empty_cache() def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) @@ -1284,7 +1282,7 @@ def discard_cuda_async_error(): a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) _ = a + b - torch.cuda.synchronize() + synchronize() except torch.AcceleratorError: #Dump it! We already know about it from the synchronous return pass @@ -1688,6 +1686,12 @@ def lora_compute_dtype(device): LORA_COMPUTE_DTYPES[device] = dtype return dtype +def synchronize(): + if is_intel_xpu(): + torch.xpu.synchronize() + elif torch.cuda.is_available(): + torch.cuda.synchronize() + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: From de9ada6a4147f4626f903cd975a5d2c134af3915 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 2 Feb 2026 14:35:20 -0800 Subject: [PATCH 13/32] Dynamic VRAM unloading fix (#12227) * mp: fix full dynamic unloading This was not unloading dynamic models when requesting a full unload via the unpatch() code path. This was ok, i your workflow was all dynamic models but fails with big VRAM leaks if you need to fully unload something for a regular ModelPatcher It also fices the "unload models" button. * mm: load models outside of Aimdo Mempool In dynamic_vram mode, escape the Aimdo mempool and load into the regular mempool. Use a dummy thread to do it. --- comfy/model_management.py | 29 ++++++++++++++++++++++------- comfy/model_patcher.py | 2 +- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2167f81bf..cd035f017 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -19,7 +19,8 @@ import psutil import logging from enum import Enum -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram +import threading import torch import sys import platform @@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ soft_empty_cache() return unloaded_models -def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): +def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): cleanup_models_gc() global vram_state @@ -746,8 +747,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu current_loaded_models.insert(0, loaded_model) return -def load_model_gpu(model): - return load_models_gpu([model]) +def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load): + with torch.inference_mode(): + load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load) + soft_empty_cache() + +def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): + #Deliberately load models outside of the Aimdo mempool so they can be retained accross + #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are + #thread local. So exploit that to escape context + if enables_dynamic_vram(): + t = threading.Thread( + target=load_models_gpu_thread, + args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load) + ) + t.start() + t.join() + else: + load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights, + minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) def loaded_models(only_currently_used=False): output = [] @@ -1717,9 +1735,6 @@ def debug_memory_summary(): return torch.cuda.memory.memory_summary() return "" -#TODO: might be cleaner to put this somewhere else -import threading - class InterruptProcessingException(Exception): pass diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b70c031bf..cdf289395 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1597,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher): if unpatch_weights: self.partially_unload_ram(1e32) - self.partially_unload(None) + self.partially_unload(None, 1e32) def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above From c05a08ae66d8114c30f8d3606b9de3117cb61ef8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 2 Feb 2026 16:52:07 -0800 Subject: [PATCH 14/32] Add back function. (#12234) --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index cd035f017..72348258b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -767,6 +767,9 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) +def load_model_gpu(model): + return load_models_gpu([model]) + def loaded_models(only_currently_used=False): output = [] for m in current_loaded_models: From ba5bf3f1a801c6d4c24f015215314b6a653f8752 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 3 Feb 2026 05:17:59 +0200 Subject: [PATCH 15/32] [API Nodes] HitPaw API nodes (#12117) * feat(api-nodes): add HitPaw API nodes * remove face_soft_2x model as not working --------- Co-authored-by: Robin Huang --- comfy_api_nodes/apis/hitpaw.py | 51 ++++ comfy_api_nodes/nodes_hitpaw.py | 342 +++++++++++++++++++++++++ comfy_api_nodes/util/upload_helpers.py | 2 +- 3 files changed, 394 insertions(+), 1 deletion(-) create mode 100644 comfy_api_nodes/apis/hitpaw.py create mode 100644 comfy_api_nodes/nodes_hitpaw.py diff --git a/comfy_api_nodes/apis/hitpaw.py b/comfy_api_nodes/apis/hitpaw.py new file mode 100644 index 000000000..b23c5d9eb --- /dev/null +++ b/comfy_api_nodes/apis/hitpaw.py @@ -0,0 +1,51 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + + +class InputVideoModel(TypedDict): + model: str + resolution: str + + +class ImageEnhanceTaskCreateRequest(BaseModel): + model_name: str = Field(...) + img_url: str = Field(...) + extension: str = Field(".png") + exif: bool = Field(False) + DPI: int | None = Field(None) + + +class VideoEnhanceTaskCreateRequest(BaseModel): + video_url: str = Field(...) + extension: str = Field(".mp4") + model_name: str | None = Field(...) + resolution: list[int] = Field(..., description="Target resolution [width, height]") + original_resolution: list[int] = Field(..., description="Original video resolution [width, height]") + + +class TaskCreateDataResponse(BaseModel): + job_id: str = Field(...) + consume_coins: int | None = Field(None) + + +class TaskStatusPollRequest(BaseModel): + job_id: str = Field(...) + + +class TaskCreateResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskCreateDataResponse | None = Field(None) + + +class TaskStatusDataResponse(BaseModel): + job_id: str = Field(...) + status: str = Field(...) + res_url: str = Field("") + + +class TaskStatusResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskStatusDataResponse = Field(...) diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py new file mode 100644 index 000000000..488080a74 --- /dev/null +++ b/comfy_api_nodes/nodes_hitpaw.py @@ -0,0 +1,342 @@ +import math + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.hitpaw import ( + ImageEnhanceTaskCreateRequest, + InputVideoModel, + TaskCreateDataResponse, + TaskCreateResponse, + TaskStatusPollRequest, + TaskStatusResponse, + VideoEnhanceTaskCreateRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + downscale_image_tensor, + get_image_dimensions, + poll_op, + sync_op, + upload_image_to_comfyapi, + upload_video_to_comfyapi, + validate_video_duration, +) + +VIDEO_MODELS_MODELS_MAP = { + "Portrait Restore Model (1x)": "portrait_restore_1x", + "Portrait Restore Model (2x)": "portrait_restore_2x", + "General Restore Model (1x)": "general_restore_1x", + "General Restore Model (2x)": "general_restore_2x", + "General Restore Model (4x)": "general_restore_4x", + "Ultra HD Model (2x)": "ultrahd_restore_2x", + "Generative Model (1x)": "generative_1x", +} + +# Resolution name to target dimension (shorter side) in pixels +RESOLUTION_TARGET_MAP = { + "720p": 720, + "1080p": 1080, + "2K/QHD": 1440, + "4K/UHD": 2160, + "8K": 4320, +} + +# Square (1:1) resolutions use standard square dimensions +RESOLUTION_SQUARE_MAP = { + "720p": 720, + "1080p": 1080, + "2K/QHD": 1440, + "4K/UHD": 2048, # DCI 4K square + "8K": 4096, # DCI 8K square +} + +# Models with limited resolution support (no 8K) +LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"} + +# Resolution options for different model types +RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"] +RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"] + +# Maximum output resolution in pixels +MAX_PIXELS_GENERATIVE = 32_000_000 +MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000 + + +class HitPawGeneralImageEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="HitPawGeneralImageEnhance", + display_name="HitPaw General Image Enhance", + category="api node/image/HitPaw", + description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. " + f"Maximum output: {MAX_MP_GENERATIVE} megapixels.", + inputs=[ + IO.Combo.Input("model", options=["generative_portrait", "generative"]), + IO.Image.Input("image"), + IO.Combo.Input("upscale_factor", options=[1, 2, 4]), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed the limit.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $prices := { + "generative_portrait": {"min": 0.02, "max": 0.06}, + "generative": {"min": 0.05, "max": 0.15} + }; + $price := $lookup($prices, widgets.model); + { + "type": "range_usd", + "min_usd": $price.min, + "max_usd": $price.max + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + upscale_factor: int, + auto_downscale: bool, + ) -> IO.NodeOutput: + height, width = get_image_dimensions(image) + requested_scale = upscale_factor + output_pixels = height * width * requested_scale * requested_scale + if output_pixels > MAX_PIXELS_GENERATIVE: + if auto_downscale: + input_pixels = width * height + scale = 1 + max_input_pixels = MAX_PIXELS_GENERATIVE + + for candidate in [4, 2, 1]: + if candidate > requested_scale: + continue + scale_output_pixels = input_pixels * candidate * candidate + if scale_output_pixels <= MAX_PIXELS_GENERATIVE: + scale = candidate + max_input_pixels = None + break + # Check if we can downscale input by at most 2x to fit + downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE) + if downscale_ratio <= 2.0: + scale = candidate + max_input_pixels = MAX_PIXELS_GENERATIVE // (candidate * candidate) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + upscale_factor = scale + else: + output_width = width * requested_scale + output_height = height * requested_scale + raise ValueError( + f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) " + f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). " + f"Enable auto_downscale or use a smaller input image or a lower upscale factor." + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"), + response_model=TaskCreateResponse, + data=ImageEnhanceTaskCreateRequest( + model_name=f"{model}_{upscale_factor}x", + img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None), + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + if initial_res.code != 200: + raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}") + request_price = initial_res.data.consume_coins / 1000 + final_response = await poll_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"), + data=TaskCreateDataResponse(job_id=initial_res.data.job_id), + response_model=TaskStatusResponse, + status_extractor=lambda x: x.data.status, + price_extractor=lambda x: request_price, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url)) + + +class HitPawVideoEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + model_options = [] + for model_name in VIDEO_MODELS_MODELS_MAP: + if model_name in LIMITED_RESOLUTION_MODELS: + resolutions = RESOLUTIONS_LIMITED + else: + resolutions = RESOLUTIONS_FULL + model_options.append( + IO.DynamicCombo.Option( + model_name, + [IO.Combo.Input("resolution", options=resolutions)], + ) + ) + + return IO.Schema( + node_id="HitPawVideoEnhance", + display_name="HitPaw Video Enhance", + category="api node/video/HitPaw", + description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. " + "Prices shown are per second of video.", + inputs=[ + IO.DynamicCombo.Input("model", options=model_options), + IO.Video.Input("video"), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]), + expr=""" + ( + $m := $lookup(widgets, "model"); + $res := $lookup(widgets, "model.resolution"); + $standard_model_prices := { + "original": {"min": 0.01, "max": 0.198}, + "720p": {"min": 0.01, "max": 0.06}, + "1080p": {"min": 0.015, "max": 0.09}, + "2k/qhd": {"min": 0.02, "max": 0.117}, + "4k/uhd": {"min": 0.025, "max": 0.152}, + "8k": {"min": 0.033, "max": 0.198} + }; + $ultra_hd_model_prices := { + "original": {"min": 0.015, "max": 0.264}, + "720p": {"min": 0.015, "max": 0.092}, + "1080p": {"min": 0.02, "max": 0.12}, + "2k/qhd": {"min": 0.026, "max": 0.156}, + "4k/uhd": {"min": 0.034, "max": 0.203}, + "8k": {"min": 0.044, "max": 0.264} + }; + $generative_model_prices := { + "original": {"min": 0.015, "max": 0.338}, + "720p": {"min": 0.008, "max": 0.090}, + "1080p": {"min": 0.05, "max": 0.15}, + "2k/qhd": {"min": 0.038, "max": 0.225}, + "4k/uhd": {"min": 0.056, "max": 0.338} + }; + $prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices : + $contains($m, "generative") ? $generative_model_prices : + $standard_model_prices; + $price := $lookup($prices, $res); + { + "type": "range_usd", + "min_usd": $price.min, + "max_usd": $price.max, + "format": {"approximate": true, "suffix": "/second"} + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: InputVideoModel, + video: Input.Video, + ) -> IO.NodeOutput: + validate_video_duration(video, min_duration=0.5, max_duration=60 * 60) + resolution = model["resolution"] + src_width, src_height = video.get_dimensions() + + if resolution == "original": + output_width = src_width + output_height = src_height + else: + if src_width == src_height: + target_size = RESOLUTION_SQUARE_MAP[resolution] + if target_size < src_width: + raise ValueError( + f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than " + f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'." + ) + output_width = target_size + output_height = target_size + else: + min_dimension = min(src_width, src_height) + target_size = RESOLUTION_TARGET_MAP[resolution] + if target_size < min_dimension: + raise ValueError( + f"Selected resolution {resolution} ({target_size}p) is smaller than " + f"the input video's shorter dimension ({min_dimension}p). " + f"Please select a higher resolution or 'original'." + ) + if src_width > src_height: + output_height = target_size + output_width = int(target_size * (src_width / src_height)) + else: + output_width = target_size + output_height = int(target_size * (src_height / src_width)) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"), + response_model=TaskCreateResponse, + data=VideoEnhanceTaskCreateRequest( + video_url=await upload_video_to_comfyapi(cls, video), + resolution=[output_width, output_height], + original_resolution=[src_width, src_height], + model_name=VIDEO_MODELS_MODELS_MAP[model["model"]], + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + request_price = initial_res.data.consume_coins / 1000 + if initial_res.code != 200: + raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"), + data=TaskStatusPollRequest(job_id=initial_res.data.job_id), + response_model=TaskStatusResponse, + status_extractor=lambda x: x.data.status, + price_extractor=lambda x: request_price, + poll_interval=10.0, + max_poll_attempts=320, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url)) + + +class HitPawExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + HitPawGeneralImageEnhance, + HitPawVideoEnhance, + ] + + +async def comfy_entrypoint() -> HitPawExtension: + return HitPawExtension() diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 3153f2b98..83d936ce1 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -94,7 +94,7 @@ async def upload_image_to_comfyapi( *, mime_type: str | None = None, wait_label: str | None = "Uploading", - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, ) -> str: """Uploads a single image to ComfyUI API and returns its download URL.""" return ( From 3c1a1a2df82fc6c7aa20a3c8301d1c632e1a1d87 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 2 Feb 2026 21:06:18 -0800 Subject: [PATCH 16/32] Basic support for the ace step 1.5 model. (#12237) --- comfy/latent_formats.py | 4 + comfy/ldm/ace/ace_step15.py | 1093 ++++++++++++++++++++++++++++++++++ comfy/model_base.py | 42 ++ comfy/model_detection.py | 5 + comfy/sd.py | 21 +- comfy/sd1_clip.py | 2 + comfy/supported_models.py | 35 +- comfy/text_encoders/ace15.py | 218 +++++++ comfy/text_encoders/llama.py | 67 +++ comfy_extras/nodes_ace.py | 75 +++ comfy_extras/nodes_audio.py | 10 +- nodes.py | 2 +- 12 files changed, 1566 insertions(+), 8 deletions(-) create mode 100644 comfy/ldm/ace/ace_step15.py create mode 100644 comfy/text_encoders/ace15.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 4b3a3798c..f59999af6 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -755,6 +755,10 @@ class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 +class ACEAudio15(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + class ChromaRadiance(LatentFormat): latent_channels = 3 spacial_downscale_ratio = 1 diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py new file mode 100644 index 000000000..d90549658 --- /dev/null +++ b/comfy/ldm/ace/ace_step15.py @@ -0,0 +1,1093 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import itertools +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management +from comfy.ldm.flux.layers import timestep_embedding + +def get_layer_class(operations, layer_name): + if operations is not None and hasattr(operations, layer_name): + return getattr(operations, layer_name) + return getattr(nn, layer_name) + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=32768, base=1000000.0, dtype=None, device=None, operations=None): + super().__init__() + self.dim = dim + self.base = base + self.max_position_embeddings = max_position_embeddings + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._set_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.get_default_dtype() if dtype is None else dtype) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len, x.device, x.dtype) + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device), + self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device), + ) + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin): + cos = cos.unsqueeze(0).unsqueeze(0) + sin = sin.unsqueeze(0).unsqueeze(0) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class MLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.gate_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.up_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.down_proj = Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device) + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, scale: float = 1000, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, dtype=dtype, device=device) + self.act1 = nn.SiLU() + self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, dtype=dtype, device=device) + self.in_channels = in_channels + self.act2 = nn.SiLU() + self.time_proj = Linear(time_embed_dim, time_embed_dim * 6, dtype=dtype, device=device) + self.scale = scale + + def forward(self, t, dtype=None): + t_freq = timestep_embedding(t, self.in_channels, time_factor=self.scale) + temb = self.linear_1(t_freq.to(dtype=dtype)) + temb = self.act1(temb) + temb = self.linear_2(temb) + timestep_proj = self.time_proj(self.act2(temb)).view(-1, 6, temb.shape[-1]) + return temb, timestep_proj + +class AceStepAttention(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + rms_norm_eps=1e-6, + is_cross_attention=False, + sliding_window=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.is_cross_attention = is_cross_attention + self.sliding_window = sliding_window + + Linear = get_layer_class(operations, "Linear") + + self.q_proj = Linear(hidden_size, num_heads * head_dim, bias=False, dtype=dtype, device=device) + self.k_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device) + self.v_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device) + self.o_proj = Linear(num_heads * head_dim, hidden_size, bias=False, dtype=dtype, device=device) + + self.q_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + position_embeddings=None, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + query_states = query_states.transpose(1, 2) + + if self.is_cross_attention and encoder_hidden_states is not None: + bsz_enc, kv_len, _ = encoder_hidden_states.size() + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + + key_states = key_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim) + key_states = self.k_norm(key_states) + value_states = value_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim) + + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + else: + kv_len = q_len + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + key_states = self.k_norm(key_states) + value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + n_rep = self.num_heads // self.num_kv_heads + if n_rep > 1: + key_states = key_states.repeat_interleave(n_rep, dim=1) + value_states = value_states.repeat_interleave(n_rep, dim=1) + + attn_bias = None + if self.sliding_window is not None and not self.is_cross_attention: + indices = torch.arange(q_len, device=query_states.device) + diff = indices.unsqueeze(1) - indices.unsqueeze(0) + in_window = torch.abs(diff) <= self.sliding_window + + window_bias = torch.zeros((q_len, kv_len), device=query_states.device, dtype=query_states.dtype) + min_value = torch.finfo(query_states.dtype).min + window_bias.masked_fill_(~in_window, min_value) + + window_bias = window_bias.unsqueeze(0).unsqueeze(0) + + if attn_bias is not None: + if attn_bias.dtype == torch.bool: + base_bias = torch.zeros_like(window_bias) + base_bias.masked_fill_(~attn_bias, min_value) + attn_bias = base_bias + window_bias + else: + attn_bias = attn_bias + window_bias + else: + attn_bias = window_bias + + attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True) + attn_output = self.o_proj(attn_output) + + return attn_output + +class AceStepDiTLayer(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + layer_type="full_attention", + sliding_window=128, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self_attn_window = sliding_window if layer_type == "sliding_attention" else None + + self.self_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.self_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=False, sliding_window=self_attn_window, + dtype=dtype, device=device, operations=operations + ) + + self.cross_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.cross_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=True, dtype=dtype, device=device, operations=operations + ) + + self.mlp_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations) + + self.scale_shift_table = nn.Parameter(torch.empty(1, 6, hidden_size, dtype=dtype, device=device)) + + def forward( + self, + hidden_states, + temb, + encoder_hidden_states, + position_embeddings, + attention_mask=None, + encoder_attention_mask=None + ): + modulation = comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = modulation.chunk(6, dim=1) + + norm_hidden = self.self_attn_norm(hidden_states) + norm_hidden = norm_hidden * (1 + scale_msa) + shift_msa + + attn_out = self.self_attn( + norm_hidden, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + hidden_states = hidden_states + attn_out * gate_msa + + norm_hidden = self.cross_attn_norm(hidden_states) + attn_out = self.cross_attn( + norm_hidden, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask + ) + hidden_states = hidden_states + attn_out + + norm_hidden = self.mlp_norm(hidden_states) + norm_hidden = norm_hidden * (1 + c_scale_msa) + c_shift_msa + + mlp_out = self.mlp(norm_hidden) + hidden_states = hidden_states + mlp_out * c_gate_msa + + return hidden_states + +class AceStepEncoderLayer(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.self_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=False, dtype=dtype, device=device, operations=operations + ) + self.input_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.post_attention_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations) + + def forward(self, hidden_states, position_embeddings, attention_mask=None): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + +class AceStepLyricEncoder(nn.Module): + def __init__( + self, + text_hidden_dim, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(text_hidden_dim, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + + def forward(self, inputs_embeds, attention_mask=None): + hidden_states = self.embed_tokens(inputs_embeds) + seq_len = hidden_states.shape[1] + cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len) + position_embeddings = (cos, sin) + + for layer in self.layers: + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + + hidden_states = self.norm(hidden_states) + return hidden_states + +class AceStepTimbreEncoder(nn.Module): + def __init__( + self, + timbre_hidden_dim, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(timbre_hidden_dim, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) + + def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask): + N, d = timbre_embs_packed.shape + device = timbre_embs_packed.device + B = N + counts = torch.bincount(refer_audio_order_mask, minlength=B) + max_count = counts.max().item() + + sorted_indices = torch.argsort( + refer_audio_order_mask * N + torch.arange(N, device=device), + stable=True + ) + sorted_batch_ids = refer_audio_order_mask[sorted_indices] + + positions = torch.arange(N, device=device) + batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]]) + positions_in_sorted = positions - batch_starts[sorted_batch_ids] + + inverse_indices = torch.empty_like(sorted_indices) + inverse_indices[sorted_indices] = torch.arange(N, device=device) + positions_in_batch = positions_in_sorted[inverse_indices] + + indices_2d = refer_audio_order_mask * max_count + positions_in_batch + one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(timbre_embs_packed.dtype) + + timbre_embs_flat = one_hot.t() @ timbre_embs_packed + timbre_embs_unpack = timbre_embs_flat.view(B, max_count, d) + + mask_flat = (one_hot.sum(dim=0) > 0).long() + new_mask = mask_flat.view(B, max_count) + return timbre_embs_unpack, new_mask + + def forward(self, refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, attention_mask=None): + hidden_states = self.embed_tokens(refer_audio_acoustic_hidden_states_packed) + if hidden_states.dim() == 2: + hidden_states = hidden_states.unsqueeze(0) + + seq_len = hidden_states.shape[1] + cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len) + + for layer in self.layers: + hidden_states = layer( + hidden_states, + position_embeddings=(cos, sin), + attention_mask=attention_mask + ) + hidden_states = self.norm(hidden_states) + + flat_states = hidden_states[:, 0, :] + unpacked_embs, unpacked_mask = self.unpack_timbre_embeddings(flat_states, refer_audio_order_mask) + return unpacked_embs, unpacked_mask + + +def pack_sequences(hidden1, hidden2, mask1, mask2): + hidden_cat = torch.cat([hidden1, hidden2], dim=1) + B, L, D = hidden_cat.shape + + if mask1 is not None and mask2 is not None: + mask_cat = torch.cat([mask1, mask2], dim=1) + sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) + gather_idx = sort_idx.unsqueeze(-1).expand(B, L, D) + hidden_sorted = torch.gather(hidden_cat, 1, gather_idx) + lengths = mask_cat.sum(dim=1) + new_mask = (torch.arange(L, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1)) + else: + new_mask = None + hidden_sorted = hidden_cat + + return hidden_sorted, new_mask + +class AceStepConditionEncoder(nn.Module): + def __init__( + self, + text_hidden_dim, + timbre_hidden_dim, + hidden_size, + num_lyric_layers, + num_timbre_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.text_projector = Linear(text_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device) + + self.lyric_encoder = AceStepLyricEncoder( + text_hidden_dim=text_hidden_dim, + hidden_size=hidden_size, + num_layers=num_lyric_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + rms_norm_eps=rms_norm_eps, + dtype=dtype, + device=device, + operations=operations + ) + + self.timbre_encoder = AceStepTimbreEncoder( + timbre_hidden_dim=timbre_hidden_dim, + hidden_size=hidden_size, + num_layers=num_timbre_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + rms_norm_eps=rms_norm_eps, + dtype=dtype, + device=device, + operations=operations + ) + + def forward( + self, + text_hidden_states=None, + text_attention_mask=None, + lyric_hidden_states=None, + lyric_attention_mask=None, + refer_audio_acoustic_hidden_states_packed=None, + refer_audio_order_mask=None + ): + text_emb = self.text_projector(text_hidden_states) + + lyric_emb = self.lyric_encoder( + inputs_embeds=lyric_hidden_states, + attention_mask=lyric_attention_mask + ) + + timbre_emb, timbre_mask = self.timbre_encoder( + refer_audio_acoustic_hidden_states_packed, + refer_audio_order_mask + ) + + merged_emb, merged_mask = pack_sequences(lyric_emb, timbre_emb, lyric_attention_mask, timbre_mask) + final_emb, final_mask = pack_sequences(merged_emb, text_emb, merged_mask, text_attention_mask) + + return final_emb, final_mask + +# -------------------------------------------------------------------------------- +# Main Diffusion Model (DiT) +# -------------------------------------------------------------------------------- + +class AceStepDiTModel(nn.Module): + def __init__( + self, + in_channels, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + patch_size, + audio_acoustic_hidden_dim, + layer_types=None, + sliding_window=128, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.patch_size = patch_size + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + Conv1d = get_layer_class(operations, "Conv1d") + ConvTranspose1d = get_layer_class(operations, "ConvTranspose1d") + Linear = get_layer_class(operations, "Linear") + + self.proj_in = nn.Sequential( + nn.Identity(), + Conv1d( + in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, + dtype=dtype, device=device)) + + self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations) + self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations) + self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + + if layer_types is None: + layer_types = ["full_attention"] * num_layers + + if len(layer_types) < num_layers: + layer_types = list(itertools.islice(itertools.cycle(layer_types), num_layers)) + + self.layers = nn.ModuleList([ + AceStepDiTLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + layer_type=layer_types[i], + sliding_window=sliding_window, + dtype=dtype, device=device, operations=operations + ) for i in range(num_layers) + ]) + + self.norm_out = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.proj_out = nn.Sequential( + nn.Identity(), + ConvTranspose1d(hidden_size, audio_acoustic_hidden_dim, kernel_size=patch_size, stride=patch_size, dtype=dtype, device=device) + ) + + self.scale_shift_table = nn.Parameter(torch.empty(1, 2, hidden_size, dtype=dtype, device=device)) + + def forward( + self, + hidden_states, + timestep, + timestep_r, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + context_latents + ): + temb_t, proj_t = self.time_embed(timestep, dtype=hidden_states.dtype) + temb_r, proj_r = self.time_embed_r(timestep - timestep_r, dtype=hidden_states.dtype) + temb = temb_t + temb_r + timestep_proj = proj_t + proj_r + + x = torch.cat([context_latents, hidden_states], dim=-1) + original_seq_len = x.shape[1] + + pad_length = 0 + if x.shape[1] % self.patch_size != 0: + pad_length = self.patch_size - (x.shape[1] % self.patch_size) + x = F.pad(x, (0, 0, 0, pad_length), mode='constant', value=0) + + x = x.transpose(1, 2) + x = self.proj_in(x) + x = x.transpose(1, 2) + + encoder_hidden_states = self.condition_embedder(encoder_hidden_states) + + seq_len = x.shape[1] + cos, sin = self.rotary_emb(x, seq_len=seq_len) + + for layer in self.layers: + x = layer( + hidden_states=x, + temb=timestep_proj, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=(cos, sin), + attention_mask=None, + encoder_attention_mask=None + ) + + shift, scale = (comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + x = self.norm_out(x) * (1 + scale) + shift + + x = x.transpose(1, 2) + x = self.proj_out(x) + x = x.transpose(1, 2) + + x = x[:, :original_seq_len, :] + return x + + +class AttentionPooler(nn.Module): + def __init__(self, hidden_size, num_layers, head_dim, rms_norm_eps, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations) + self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device)) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, 16, 8, head_dim, hidden_size * 3, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + + def forward(self, x): + B, T, P, D = x.shape + x = self.embed_tokens(x) + special = self.special_token.expand(B, T, 1, -1) + x = torch.cat([special, x], dim=2) + x = x.view(B * T, P + 1, D) + + cos, sin = self.rotary_emb(x, seq_len=P + 1) + for layer in self.layers: + x = layer(x, (cos, sin)) + + x = self.norm(x) + return x[:, 0, :].view(B, T, D) + + +class FSQ(nn.Module): + def __init__( + self, + levels, + dim=None, + device=None, + dtype=None, + operations=None + ): + super().__init__() + + _levels = torch.tensor(levels, dtype=torch.int32, device=device) + self.register_buffer('_levels', _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.int32, device=device), dim=0) + self.register_buffer('_basis', _basis, persistent=False) + + self.codebook_dim = len(levels) + self.dim = dim if dim is not None else self.codebook_dim + + requires_projection = self.dim != self.codebook_dim + if requires_projection: + self.project_in = operations.Linear(self.dim, self.codebook_dim, device=device, dtype=dtype) + self.project_out = operations.Linear(self.codebook_dim, self.dim, device=device, dtype=dtype) + else: + self.project_in = nn.Identity() + self.project_out = nn.Identity() + + self.codebook_size = self._levels.prod().item() + + indices = torch.arange(self.codebook_size, device=device) + implicit_codebook = self._indices_to_codes(indices) + + if dtype is not None: + implicit_codebook = implicit_codebook.to(dtype) + + self.register_buffer('implicit_codebook', implicit_codebook, persistent=False) + + def bound(self, z): + levels_minus_1 = (self._levels - 1).to(z.dtype) + scale = 2. / levels_minus_1 + bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5 + + zhat = bracket.floor() + bracket_ste = bracket + (zhat - bracket).detach() + + return scale * bracket_ste - 1. + + def _indices_to_codes(self, indices): + indices = indices.unsqueeze(-1) + codes_non_centered = (indices // self._basis) % self._levels + return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1. + + def codes_to_indices(self, zhat): + zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1)) + return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32) + + def forward(self, z): + orig_dtype = z.dtype + z = self.project_in(z) + + codes = self.bound(z) + indices = self.codes_to_indices(codes) + + out = self.project_out(codes) + return out.to(orig_dtype), indices + + +class ResidualFSQ(nn.Module): + def __init__( + self, + levels, + num_quantizers, + dim=None, + bound_hard_clamp=True, + device=None, + dtype=None, + operations=None, + **kwargs + ): + super().__init__() + + codebook_dim = len(levels) + dim = dim if dim is not None else codebook_dim + + requires_projection = codebook_dim != dim + if requires_projection: + self.project_in = operations.Linear(dim, codebook_dim, device=device, dtype=dtype) + self.project_out = operations.Linear(codebook_dim, dim, device=device, dtype=dtype) + else: + self.project_in = nn.Identity() + self.project_out = nn.Identity() + + self.layers = nn.ModuleList() + levels_tensor = torch.tensor(levels, device=device) + scales = [] + + for ind in range(num_quantizers): + scale_val = levels_tensor.float() ** -ind + scales.append(scale_val) + + self.layers.append(FSQ( + levels=levels, + dim=codebook_dim, + device=device, + dtype=dtype, + operations=operations + )) + + scales_tensor = torch.stack(scales) + if dtype is not None: + scales_tensor = scales_tensor.to(dtype) + self.register_buffer('scales', scales_tensor, persistent=False) + + if bound_hard_clamp: + val = 1 + (1 / (levels_tensor.float() - 1)) + if dtype is not None: + val = val.to(dtype) + self.register_buffer('soft_clamp_input_value', val, persistent=False) + + def get_output_from_indices(self, indices, dtype=torch.float32): + if indices.dim() == 2: + indices = indices.unsqueeze(-1) + + all_codes = [] + for i, layer in enumerate(self.layers): + idx = indices[..., i].long() + codes = F.embedding(idx, comfy.model_management.cast_to(layer.implicit_codebook, device=idx.device, dtype=dtype)) + all_codes.append(codes * comfy.model_management.cast_to(self.scales[i], device=idx.device, dtype=dtype)) + + codes_summed = torch.stack(all_codes).sum(dim=0) + return self.project_out(codes_summed) + + def forward(self, x): + x = self.project_in(x) + + if hasattr(self, 'soft_clamp_input_value'): + sc_val = self.soft_clamp_input_value.to(x.dtype) + x = (x / sc_val).tanh() * sc_val + + quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype) + residual = x + all_indices = [] + + for layer, scale in zip(self.layers, self.scales): + scale = scale.to(residual.dtype) + + quantized, indices = layer(residual / scale) + quantized = quantized * scale + + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + all_indices.append(indices) + + quantized_out = self.project_out(quantized_out) + all_indices = torch.stack(all_indices, dim=-1) + + return quantized_out, all_indices + + +class AceStepAudioTokenizer(nn.Module): + def __init__( + self, + audio_acoustic_hidden_dim, + hidden_size, + pool_window_size, + fsq_dim, + fsq_levels, + fsq_input_num_quantizers, + num_layers, + head_dim, + rms_norm_eps, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.audio_acoustic_proj = Linear(audio_acoustic_hidden_dim, hidden_size, dtype=dtype, device=device) + self.attention_pooler = AttentionPooler( + hidden_size, num_layers, head_dim, rms_norm_eps, dtype=dtype, device=device, operations=operations + ) + self.pool_window_size = pool_window_size + self.fsq_dim = fsq_dim + self.quantizer = ResidualFSQ( + dim=fsq_dim, + levels=fsq_levels, + num_quantizers=fsq_input_num_quantizers, + bound_hard_clamp=True, + dtype=dtype, device=device, operations=operations + ) + + def forward(self, hidden_states): + hidden_states = self.audio_acoustic_proj(hidden_states) + hidden_states = self.attention_pooler(hidden_states) + quantized, indices = self.quantizer(hidden_states) + return quantized, indices + + def tokenize(self, x): + B, T, D = x.shape + P = self.pool_window_size + + if T % P != 0: + pad = P - (T % P) + x = F.pad(x, (0, 0, 0, pad)) + T = x.shape[1] + + T_patch = T // P + x = x.view(B, T_patch, P, D) + + quantized, indices = self.forward(x) + return quantized, indices + + +class AudioTokenDetokenizer(nn.Module): + def __init__( + self, + hidden_size, + pool_window_size, + audio_acoustic_hidden_dim, + num_layers, + head_dim, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.pool_window_size = pool_window_size + self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.special_tokens = nn.Parameter(torch.empty(1, pool_window_size, hidden_size, dtype=dtype, device=device)) + self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations) + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, 16, 8, head_dim, hidden_size * 3, 1e-6, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) + self.proj_out = Linear(hidden_size, audio_acoustic_hidden_dim, dtype=dtype, device=device) + + def forward(self, x): + B, T, D = x.shape + x = self.embed_tokens(x) + x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1) + x = x + comfy.model_management.cast_to(self.special_tokens.expand(B, T, -1, -1), device=x.device, dtype=x.dtype) + x = x.view(B * T, self.pool_window_size, D) + + cos, sin = self.rotary_emb(x, seq_len=self.pool_window_size) + for layer in self.layers: + x = layer(x, (cos, sin)) + + x = self.norm(x) + x = self.proj_out(x) + return x.view(B, T * self.pool_window_size, -1) + + +class AceStepConditionGenerationModel(nn.Module): + def __init__( + self, + in_channels=192, + hidden_size=2048, + text_hidden_dim=1024, + timbre_hidden_dim=64, + audio_acoustic_hidden_dim=64, + num_dit_layers=24, + num_lyric_layers=8, + num_timbre_layers=4, + num_tokenizer_layers=2, + num_heads=16, + num_kv_heads=8, + head_dim=128, + intermediate_size=6144, + patch_size=2, + pool_window_size=5, + rms_norm_eps=1e-06, + timestep_mu=-0.4, + timestep_sigma=1.0, + data_proportion=0.5, + sliding_window=128, + layer_types=None, + fsq_dim=2048, + fsq_levels=[8, 8, 8, 5, 5, 5], + fsq_input_num_quantizers=1, + audio_model=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dtype = dtype + self.timestep_mu = timestep_mu + self.timestep_sigma = timestep_sigma + self.data_proportion = data_proportion + self.pool_window_size = pool_window_size + + if layer_types is None: + layer_types = [] + for i in range(num_dit_layers): + layer_types.append("sliding_attention" if i % 2 == 0 else "full_attention") + + self.decoder = AceStepDiTModel( + in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim, + intermediate_size, patch_size, audio_acoustic_hidden_dim, + layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.encoder = AceStepConditionEncoder( + text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers, + num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.tokenizer = AceStepAudioTokenizer( + audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.detokenizer = AudioTokenDetokenizer( + hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim, + dtype=dtype, device=device, operations=operations + ) + self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device)) + + def prepare_condition( + self, + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, + src_latents, chunk_masks, is_covers, + precomputed_lm_hints_25Hz=None, + audio_codes=None + ): + encoder_hidden, encoder_mask = self.encoder( + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask + ) + + if precomputed_lm_hints_25Hz is not None: + lm_hints = precomputed_lm_hints_25Hz + else: + if audio_codes is not None: + if audio_codes.shape[1] * 5 < src_latents.shape[1]: + audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847) + lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype) + else: + assert False + # TODO ? + + lm_hints = self.detokenizer(lm_hints_5Hz) + + lm_hints = lm_hints[:, :src_latents.shape[1], :] + if is_covers is None: + src_latents = lm_hints + else: + src_latents = torch.where(is_covers.unsqueeze(-1).unsqueeze(-1) > 0, lm_hints, src_latents) + + context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1) + + return encoder_hidden, encoder_mask, context_latents + + def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, **kwargs): + text_attention_mask = None + lyric_attention_mask = None + refer_audio_order_mask = None + attention_mask = None + chunk_masks = None + is_covers = None + src_latents = None + precomputed_lm_hints_25Hz = None + lyric_hidden_states = lyric_embed + text_hidden_states = context + refer_audio_acoustic_hidden_states_packed = refer_audio.movedim(-1, -2) + + x = x.movedim(-1, -2) + + if refer_audio_order_mask is None: + refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long) + + if src_latents is None and is_covers is None: + src_latents = x + + if chunk_masks is None: + chunk_masks = torch.ones_like(x) + + enc_hidden, enc_mask, context_latents = self.prepare_condition( + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, + src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes + ) + + out = self.decoder(hidden_states=x, + timestep=timestep, + timestep_r=timestep, + attention_mask=attention_mask, + encoder_hidden_states=enc_hidden, + encoder_attention_mask=enc_mask, + context_latents=context_latents + ) + + return out.movedim(-1, -2) diff --git a/comfy/model_base.py b/comfy/model_base.py index 85acdb66a..89944548c 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -50,6 +50,7 @@ import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model +import comfy.ldm.ace.ace_step15 import comfy.model_management import comfy.patcher_extension @@ -1540,6 +1541,47 @@ class ACEStep(BaseModel): out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) return out +class ACEStep15(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + device = kwargs["device"] + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_lyrics = kwargs.get("conditioning_lyrics", None) + if cross_attn is not None: + out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics) + + refer_audio = kwargs.get("reference_audio_timbre_latents", None) + if refer_audio is None or len(refer_audio) == 0: + refer_audio = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02, + 2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01, + -2.7710e-02, -1.8066e-01, -2.9688e-01, 1.6016e+00, -2.6719e+00, + 7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00, + 2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01, + -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00, + 5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01, + -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01, + 2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01, + 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01, + -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01, + -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01, + 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, 750) + else: + refer_audio = refer_audio[-1] + out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) + + audio_codes = kwargs.get("audio_codes", None) + if audio_codes is not None: + out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device)) + + return out + class Omnigen2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8cea16e50..e8ad725df 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -655,6 +655,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') return dit_config + if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["audio_model"] = "ace1.5" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/sd.py b/comfy/sd.py index fd0ac85e7..722c0c154 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -59,6 +59,7 @@ import comfy.text_encoders.kandinsky5 import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie import comfy.text_encoders.anima +import comfy.text_encoders.ace15 import comfy.model_patcher import comfy.lora @@ -452,6 +453,8 @@ class VAE: self.extra_1d_channel = None self.crop_input = True + self.audio_sample_rate = 44100 + if config is None: if "decoder.mid.block_1.mix_factor" in sd: encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -549,14 +552,25 @@ class VAE: encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) elif "decoder.layers.1.layers.0.beta" in sd: - self.first_stage_model = AudioOobleckVAE() + config = {} + param_key = None + if "decoder.layers.2.layers.1.weight_v" in sd: + param_key = "decoder.layers.2.layers.1.weight_v" + if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd: + param_key = "decoder.layers.2.layers.1.parametrizations.weight.original1" + if param_key is not None: + if sd[param_key].shape[-1] == 12: + config["strides"] = [2, 4, 4, 6, 10] + self.audio_sample_rate = 48000 + + self.first_stage_model = AudioOobleckVAE(**config) self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype) self.latent_channels = 64 self.output_channels = 2 self.pad_channel_value = "replicate" self.upscale_ratio = 2048 - self.downscale_ratio = 2048 + self.downscale_ratio = 2048 self.latent_dim = 1 self.process_output = lambda audio: audio self.process_input = lambda audio: audio @@ -1427,6 +1441,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_data_jina = clip_data[0] tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) + elif clip_type == CLIPType.ACE: + clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 9ecfc9c55..4c817d468 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -155,6 +155,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.execution_device = options.get("execution_device", self.execution_device) if isinstance(self.layer, list) or self.layer == "all": pass + elif isinstance(layer_idx, list): + self.layer = layer_idx elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d25271d6e..6b7d831cb 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -24,6 +24,7 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image import comfy.text_encoders.anima +import comfy.text_encoders.ace15 from . import supported_models_base from . import latent_formats @@ -1596,6 +1597,38 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +class ACEStep15(supported_models_base.BASE): + unet_config = { + "audio_model": "ace1.5", + } + + unet_extra_config = { + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + latent_format = comfy.latent_formats.ACEAudio15 + + memory_usage_factor = 4.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.ACEStep15(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py new file mode 100644 index 000000000..9070cb577 --- /dev/null +++ b/comfy/text_encoders/ace15.py @@ -0,0 +1,218 @@ +from .anima import Qwen3Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import torch +import math + + +def sample_manual_loop_no_classes( + model, + ids=None, + paddings=[], + execution_dtype=None, + cfg_scale: float = 2.0, + temperature: float = 0.85, + top_p: float = 0.9, + top_k: int = None, + seed: int = 1, + min_tokens: int = 1, + max_new_tokens: int = 2048, + audio_start_id: int = 151669, # The cutoff ID for audio codes + eos_token_id: int = 151645, +): + device = model.execution_device + + if execution_dtype is None: + if comfy.model_management.should_use_bf16(device): + execution_dtype = torch.bfloat16 + else: + execution_dtype = torch.float32 + + embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device) + for i, t in enumerate(paddings): + attention_mask[i, :t] = 0 + attention_mask[i, t:] = 1 + + output_audio_codes = [] + past_key_values = [] + generator = torch.Generator(device=device) + generator.manual_seed(seed) + model_config = model.transformer.model.config + + for x in range(model_config.num_hidden_layers): + past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0)) + + for step in range(max_new_tokens): + outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) + next_token_logits = model.transformer.logits(outputs[0])[:, -1] + past_key_values = outputs[2] + + cond_logits = next_token_logits[0:1] + uncond_logits = next_token_logits[1:2] + cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits) + + if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: + eos_score = cfg_logits[:, eos_token_id].clone() + + # Only generate audio tokens + cfg_logits[:, :audio_start_id] = float('-inf') + + if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: + cfg_logits[:, eos_token_id] = eos_score + + if top_k is not None and top_k > 0: + top_k_vals, _ = torch.topk(cfg_logits, top_k) + min_val = top_k_vals[..., -1, None] + cfg_logits[cfg_logits < min_val] = float('-inf') + + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + cfg_logits[indices_to_remove] = float('-inf') + + if temperature > 0: + cfg_logits = cfg_logits / temperature + next_token = torch.multinomial(torch.softmax(cfg_logits, dim=-1), num_samples=1, generator=generator).squeeze(1) + else: + next_token = torch.argmax(cfg_logits, dim=-1) + + token = next_token.item() + + if token == eos_token_id: + break + + embed, _, _, _ = model.process_tokens([[token]], device) + embeds = embed.repeat(2, 1, 1) + attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1) + + output_audio_codes.append(token - audio_start_id) + + return output_audio_codes + + +def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0): + cfg_scale = 2.0 + + positive = [[token for token, _ in inner_list] for inner_list in positive] + negative = [[token for token, _ in inner_list] for inner_list in negative] + positive = positive[0] + negative = negative[0] + + neg_pad = 0 + if len(negative) < len(positive): + neg_pad = (len(positive) - len(negative)) + negative = [model.special_tokens["pad"]] * neg_pad + negative + + pos_pad = 0 + if len(negative) > len(positive): + pos_pad = (len(negative) - len(positive)) + positive = [model.special_tokens["pad"]] * pos_pad + positive + + paddings = [pos_pad, neg_pad] + return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) + + +class ACE15Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer) + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + out = {} + lyrics = kwargs.get("lyrics", "") + bpm = kwargs.get("bpm", 120) + duration = kwargs.get("duration", 120) + keyscale = kwargs.get("keyscale", "C major") + timesignature = kwargs.get("timesignature", 2) + language = kwargs.get("language", "en") + seed = kwargs.get("seed", 0) + + duration = math.ceil(duration) + meta_lm = 'bpm: {}\nduration: {}\nkeyscale: {}\ntimesignature: {}'.format(bpm, duration, keyscale, timesignature) + lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n{}\n<|im_end|>\n<|im_start|>assistant\n\n{}\n\n\n<|im_end|>\n" + + meta_cap = '- bpm: {}\n- timesignature: {}\n- keyscale: {}\n- duration: {}\n'.format(bpm, timesignature, keyscale, duration) + out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, meta_lm), disable_weights=True) + out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, ""), disable_weights=True) + + out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric{}<|endoftext|><|endoftext|>".format(language, lyrics), return_word_ids, disable_weights=True, **kwargs) + out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}# Metas\n{}<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs) + out["lm_metadata"] = {"min_tokens": duration * 5, "seed": seed} + return out + + +class Qwen3_06BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B_ACE15, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Qwen3_2B_ACE15(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class ACE15TEModel(torch.nn.Module): + def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}): + super().__init__() + if dtype_llama is None: + dtype_llama = dtype + + self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options) + self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options) + self.dtypes = set([dtype, dtype_llama]) + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_base = token_weight_pairs["qwen3_06b"] + token_weight_pairs_lyrics = token_weight_pairs["lyrics"] + + self.qwen3_06b.set_clip_options({"layer": None}) + base_out, _, extra = self.qwen3_06b.encode_token_weights(token_weight_pairs_base) + self.qwen3_06b.set_clip_options({"layer": [0]}) + lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics) + + lm_metadata = token_weight_pairs["lm_metadata"] + audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"]) + + return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]} + + def set_clip_options(self, options): + self.qwen3_06b.set_clip_options(options) + self.qwen3_2b.set_clip_options(options) + + def reset_clip_options(self): + self.qwen3_06b.reset_clip_options() + self.qwen3_2b.reset_clip_options() + + def load_sd(self, sd): + if "model.layers.0.post_attention_layernorm.weight" in sd: + shape = sd["model.layers.0.post_attention_layernorm.weight"].shape + if shape[0] == 1024: + return self.qwen3_06b.load_sd(sd) + else: + return self.qwen3_2b.load_sd(sd) + + def memory_estimation_function(self, token_weight_pairs, device=None): + lm_metadata = token_weight_pairs["lm_metadata"] + constant = 0.4375 + if comfy.model_management.should_use_bf16(device): + constant *= 0.5 + + token_weight_pairs = token_weight_pairs.get("lm_prompt", []) + num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) + num_tokens += lm_metadata['min_tokens'] + return num_tokens * constant * 1024 * 1024 + +def te(dtype_llama=None, llama_quantization_metadata=None): + class ACE15TEModel_(ACE15TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options) + return ACE15TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 68ac1e804..d2324ffc5 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -103,6 +103,52 @@ class Qwen3_06BConfig: final_norm: bool = True lm_head: bool = False +@dataclass +class Qwen3_06B_ACE15_Config: + vocab_size: int = 151669 + hidden_size: int = 1024 + intermediate_size: int = 3072 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 32768 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + +@dataclass +class Qwen3_2B_ACE15_lm_Config: + vocab_size: int = 217204 + hidden_size: int = 2048 + intermediate_size: int = 6144 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + @dataclass class Qwen3_4BConfig: vocab_size: int = 151936 @@ -729,6 +775,27 @@ class Qwen3_06B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_06B_ACE15_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_2B_ACE15_lm_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + + def logits(self, x): + return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None) + class Qwen3_4B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index 1409233c9..376584e5c 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -28,12 +28,39 @@ class TextEncodeAceStepAudio(io.ComfyNode): conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) return io.NodeOutput(conditioning) +class TextEncodeAceStepAudio15(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeAceStepAudio1.5", + category="conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("tags", multiline=True, dynamic_prompts=True), + io.String.Input("lyrics", multiline=True, dynamic_prompts=True), + io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Int.Input("bpm", default=120, min=10, max=300), + io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1), + io.Combo.Input("timesignature", options=['2', '3', '4', '6']), + io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), + io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), + ], + outputs=[io.Conditioning.Output()], + ) + + @classmethod + def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale) -> io.NodeOutput: + tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed) + conditioning = clip.encode_from_tokens_scheduled(tokens) + return io.NodeOutput(conditioning) + class EmptyAceStepLatentAudio(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="EmptyAceStepLatentAudio", + display_name="Empty Ace Step 1.0 Latent Audio", category="latent/audio", inputs=[ io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), @@ -51,12 +78,60 @@ class EmptyAceStepLatentAudio(io.ComfyNode): return io.NodeOutput({"samples": latent, "type": "audio"}) +class EmptyAceStep15LatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyAceStep1.5LatentAudio", + display_name="Empty Ace Step 1.5 Latent Audio", + category="latent/audio", + inputs=[ + io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), + io.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[io.Latent.Output()], + ) + + @classmethod + def execute(cls, seconds, batch_size) -> io.NodeOutput: + length = round((seconds * 48000 / 1920)) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "type": "audio"}) + +class ReferenceTimbreAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ReferenceTimbreAudio", + category="advanced/conditioning/audio", + is_experimental=True, + description="This node sets the reference audio for timbre (for ace step 1.5)", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) + + @classmethod + def execute(cls, conditioning, latent=None) -> io.NodeOutput: + if latent is not None: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True) + return io.NodeOutput(conditioning) + class AceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeAceStepAudio, EmptyAceStepLatentAudio, + TextEncodeAceStepAudio15, + EmptyAceStep15LatentAudio, + ReferenceTimbreAudio, ] async def comfy_entrypoint() -> AceExtension: diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 271b75fbd..bef723dce 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -82,13 +82,14 @@ class VAEEncodeAudio(IO.ComfyNode): @classmethod def execute(cls, vae, audio) -> IO.NodeOutput: sample_rate = audio["sample_rate"] - if 44100 != sample_rate: - waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) + vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + if vae_sample_rate != sample_rate: + waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate) else: waveform = audio["waveform"] t = vae.encode(waveform.movedim(1, -1)) - return IO.NodeOutput({"samples":t}) + return IO.NodeOutput({"samples": t}) encode = execute # TODO: remove @@ -114,7 +115,8 @@ class VAEDecodeAudio(IO.ComfyNode): std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]}) + vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + return IO.NodeOutput({"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]}) decode = execute # TODO: remove diff --git a/nodes.py b/nodes.py index 1cb43d9e2..e11a8ed80 100644 --- a/nodes.py +++ b/nodes.py @@ -1001,7 +1001,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie", "ace"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), From be4345d1c9327c7a9292bc8f4153d35717020ef3 Mon Sep 17 00:00:00 2001 From: ComfyUI Wiki Date: Tue, 3 Feb 2026 15:08:43 +0800 Subject: [PATCH 17/32] chore: update workflow templates to v0.8.31 (#12239) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 3ca417dd8..0c401873a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.37.11 -comfyui-workflow-templates==0.8.27 +comfyui-workflow-templates==0.8.31 comfyui-embedded-docs==0.4.0 torch torchsde From 66e1b07402a84363b174c35d09acfcab0af7b0e5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 3 Feb 2026 02:20:59 -0500 Subject: [PATCH 18/32] ComfyUI v0.12.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index b1ebaa115..bc6076b67 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.11.1" +__version__ = "0.12.0" diff --git a/pyproject.toml b/pyproject.toml index 042f124e4..28aa03067 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.11.1" +version = "0.12.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From f5030e26fd72bdf19ab2b6463dc502fa0b2e34ba Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 01:09:30 -0800 Subject: [PATCH 19/32] Add progress bar to ace step. (#12242) --- comfy/text_encoders/ace15.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 9070cb577..75e77151d 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -3,6 +3,7 @@ import comfy.text_encoders.llama from comfy import sd1_clip import torch import math +import comfy.utils def sample_manual_loop_no_classes( @@ -42,6 +43,8 @@ def sample_manual_loop_no_classes( for x in range(model_config.num_hidden_layers): past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0)) + progress_bar = comfy.utils.ProgressBar(max_new_tokens) + for step in range(max_new_tokens): outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) next_token_logits = model.transformer.logits(outputs[0])[:, -1] @@ -90,6 +93,7 @@ def sample_manual_loop_no_classes( attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1) output_audio_codes.append(token - audio_start_id) + progress_bar.update_absolute(step) return output_audio_codes From affe881354d1e41df9ac9394b2d8257faebfdba1 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 08:07:04 -0800 Subject: [PATCH 20/32] Fix some issues with mac. (#12247) --- comfy/text_encoders/ace15.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 75e77151d..48c29fa9a 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -57,8 +57,9 @@ def sample_manual_loop_no_classes( if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: eos_score = cfg_logits[:, eos_token_id].clone() + remove_logit_value = torch.finfo(cfg_logits.dtype).min # Only generate audio tokens - cfg_logits[:, :audio_start_id] = float('-inf') + cfg_logits[:, :audio_start_id] = remove_logit_value if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: cfg_logits[:, eos_token_id] = eos_score @@ -66,7 +67,7 @@ def sample_manual_loop_no_classes( if top_k is not None and top_k > 0: top_k_vals, _ = torch.topk(cfg_logits, top_k) min_val = top_k_vals[..., -1, None] - cfg_logits[cfg_logits < min_val] = float('-inf') + cfg_logits[cfg_logits < min_val] = remove_logit_value if top_p is not None and top_p < 1.0: sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True) @@ -75,7 +76,7 @@ def sample_manual_loop_no_classes( sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - cfg_logits[indices_to_remove] = float('-inf') + cfg_logits[indices_to_remove] = remove_logit_value if temperature > 0: cfg_logits = cfg_logits / temperature From 223364743c35d6e1dc4f5ebc3796234d0f8484cf Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 08:31:36 -0800 Subject: [PATCH 21/32] llama: cast logits as a comfy-weight (#12248) This is using a different layers weight with .to(). Change it to use the ops caster if the original layer is a comfy weight so that it picks up dynamic_vram and async_offload functionality in full. Co-authored-by: Rattus --- comfy/text_encoders/llama.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index d2324ffc5..d1c628d20 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -6,6 +6,7 @@ import math from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management +import comfy.ops import comfy.ldm.common_dit import comfy.clip_model @@ -794,7 +795,19 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module): self.dtype = dtype def logits(self, x): - return torch.nn.functional.linear(x[:, -1:], self.model.embed_tokens.weight.to(x), None) + input = x[:, -1:] + module = self.model.embed_tokens + + offload_stream = None + if module.comfy_cast_weights: + weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True) + else: + weight = self.model.embed_tokens.weight.to(x) + + x = torch.nn.functional.linear(input, weight, None) + + comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) + return x class Qwen3_4B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): From 85fc35e8fa44c6174425acb4f9167792bcc903a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 09:19:39 -0800 Subject: [PATCH 22/32] Fix mac issue. (#12250) --- comfy/text_encoders/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index d1c628d20..f4ac224c6 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -628,10 +628,10 @@ class Llama2_(nn.Module): mask = None if attention_mask is not None: mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) - mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) + mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min) if seq_len > 1: - causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min).triu_(1) if mask is not None: mask += causal_mask else: From fb23935c1139875d47e7b6e239f657bae73bda5b Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 3 Feb 2026 20:31:46 +0200 Subject: [PATCH 23/32] feat(comfy_api): add basic 3D Model file types (#12129) * feat(comfy_api): add basic 3D Model file types * update Tripo nodes to use File3DGLB * update Rodin3D nodes to use File3DGLB * address PR review feedback: - Rename File3D parameter 'path' to 'source' - Convert File3D.data property to get_data() - Make .glb extension check case-insensitive in nodes_rodin.py - Restrict SaveGLB node to only accept File3DGLB * Fixed a bug in the Meshy Rig and Animation nodes * Fix backward compatability --- comfy_api/latest/__init__.py | 3 +- comfy_api/latest/_io.py | 52 +++++++++- comfy_api/latest/_util/__init__.py | 3 +- comfy_api/latest/_util/geometry_types.py | 77 ++++++++++++++ comfy_api_nodes/apis/meshy.py | 5 + comfy_api_nodes/nodes_hunyuan3d.py | 45 ++++---- comfy_api_nodes/nodes_meshy.py | 116 ++++++++++++++------- comfy_api_nodes/nodes_rodin.py | 67 ++++++++---- comfy_api_nodes/nodes_tripo.py | 126 +++++++++++------------ comfy_api_nodes/util/__init__.py | 2 + comfy_api_nodes/util/download_helpers.py | 38 ++++++- comfy_extras/nodes_hunyuan3d.py | 27 ++++- comfy_extras/nodes_load_3d.py | 26 ++++- 13 files changed, 427 insertions(+), 160 deletions(-) diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index b0fa14ff6..8542a1dbc 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input_impl import VideoFromFile, VideoFromComponents -from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D from . import _io_public as io from . import _ui_public as ui from comfy_execution.utils import get_executing_context @@ -105,6 +105,7 @@ class Types: VideoComponents = VideoComponents MESH = MESH VOXEL = VOXEL + File3D = File3D ComfyAPI = ComfyAPI_latest diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index eeea9781a..93cf482ca 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL, SVG as _SVG +from ._util import MESH, VOXEL, SVG as _SVG, File3D class FolderType(str, Enum): @@ -667,6 +667,49 @@ class Voxel(ComfyTypeIO): class Mesh(ComfyTypeIO): Type = MESH + +@comfytype(io_type="FILE_3D") +class File3DAny(ComfyTypeIO): + """General 3D file type - accepts any supported 3D format.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_GLB") +class File3DGLB(ComfyTypeIO): + """GLB format 3D file - binary glTF, best for web and cross-platform.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_GLTF") +class File3DGLTF(ComfyTypeIO): + """GLTF format 3D file - JSON-based glTF with external resources.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_FBX") +class File3DFBX(ComfyTypeIO): + """FBX format 3D file - best for game engines and animation.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_OBJ") +class File3DOBJ(ComfyTypeIO): + """OBJ format 3D file - simple geometry format.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_STL") +class File3DSTL(ComfyTypeIO): + """STL format 3D file - best for 3D printing.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_USDZ") +class File3DUSDZ(ComfyTypeIO): + """USDZ format 3D file - Apple AR format.""" + Type = File3D + + @comfytype(io_type="HOOKS") class Hooks(ComfyTypeIO): if TYPE_CHECKING: @@ -2037,6 +2080,13 @@ __all__ = [ "LossMap", "Voxel", "Mesh", + "File3DAny", + "File3DGLB", + "File3DGLTF", + "File3DFBX", + "File3DOBJ", + "File3DSTL", + "File3DUSDZ", "Hooks", "HookKeyframes", "TimestepsRange", diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index 6313eb01b..115baf392 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,5 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents -from .geometry_types import VOXEL, MESH +from .geometry_types import VOXEL, MESH, File3D from .image_types import SVG __all__ = [ @@ -9,5 +9,6 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "File3D", "SVG", ] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index 385122778..b586fceb3 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -1,3 +1,8 @@ +import shutil +from io import BytesIO +from pathlib import Path +from typing import IO + import torch @@ -10,3 +15,75 @@ class MESH: def __init__(self, vertices: torch.Tensor, faces: torch.Tensor): self.vertices = vertices self.faces = faces + + +class File3D: + """Class representing a 3D file from a file path or binary stream. + + Supports both disk-backed (file path) and memory-backed (BytesIO) storage. + """ + + def __init__(self, source: str | IO[bytes], file_format: str = ""): + self._source = source + self._format = file_format or self._infer_format() + + def _infer_format(self) -> str: + if isinstance(self._source, str): + return Path(self._source).suffix.lstrip(".").lower() + return "" + + @property + def format(self) -> str: + return self._format + + @format.setter + def format(self, value: str) -> None: + self._format = value.lstrip(".").lower() if value else "" + + @property + def is_disk_backed(self) -> bool: + return isinstance(self._source, str) + + def get_source(self) -> str | IO[bytes]: + if isinstance(self._source, str): + return self._source + if hasattr(self._source, "seek"): + self._source.seek(0) + return self._source + + def get_data(self) -> BytesIO: + if isinstance(self._source, str): + with open(self._source, "rb") as f: + result = BytesIO(f.read()) + return result + if hasattr(self._source, "seek"): + self._source.seek(0) + if isinstance(self._source, BytesIO): + return self._source + return BytesIO(self._source.read()) + + def save_to(self, path: str) -> str: + dest = Path(path) + dest.parent.mkdir(parents=True, exist_ok=True) + + if isinstance(self._source, str): + if Path(self._source).resolve() != dest.resolve(): + shutil.copy2(self._source, dest) + else: + if hasattr(self._source, "seek"): + self._source.seek(0) + with open(dest, "wb") as f: + f.write(self._source.read()) + return str(dest) + + def get_bytes(self) -> bytes: + if isinstance(self._source, str): + return Path(self._source).read_bytes() + if hasattr(self._source, "seek"): + self._source.seek(0) + return self._source.read() + + def __repr__(self) -> str: + if isinstance(self._source, str): + return f"File3D(source={self._source!r}, format={self._format!r})" + return f"File3D(, format={self._format!r})" diff --git a/comfy_api_nodes/apis/meshy.py b/comfy_api_nodes/apis/meshy.py index be46d0d58..7d72e6e91 100644 --- a/comfy_api_nodes/apis/meshy.py +++ b/comfy_api_nodes/apis/meshy.py @@ -109,14 +109,19 @@ class MeshyTextureRequest(BaseModel): class MeshyModelsUrls(BaseModel): glb: str = Field("") + fbx: str = Field("") + usdz: str = Field("") + obj: str = Field("") class MeshyRiggedModelsUrls(BaseModel): rigged_character_glb_url: str = Field("") + rigged_character_fbx_url: str = Field("") class MeshyAnimatedModelsUrls(BaseModel): animation_glb_url: str = Field("") + animation_fbx_url: str = Field("") class MeshyResultTextureUrls(BaseModel): diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index b3a736643..813a7c809 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -1,5 +1,3 @@ -import os - from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input @@ -14,7 +12,7 @@ from comfy_api_nodes.apis.hunyuan3d import ( ) from comfy_api_nodes.util import ( ApiEndpoint, - download_url_to_bytesio, + download_url_to_file_3d, downscale_image_tensor_by_max_side, poll_op, sync_op, @@ -22,14 +20,13 @@ from comfy_api_nodes.util import ( validate_image_dimensions, validate_string, ) -from folder_paths import get_output_directory -def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D: +def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None: for i in response_objs: - if i.Type.lower() == "glb": + if i.Type.lower() == file_type.lower(): return i - raise ValueError("No GLB file found in response. Please report this to the developers.") + return None class TencentTextToModelNode(IO.ComfyNode): @@ -74,7 +71,9 @@ class TencentTextToModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DOBJ.Output(display_name="OBJ"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -124,19 +123,20 @@ class TencentTextToModelNode(IO.ComfyNode): ) if response.Error: raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + task_id = response.JobId result = await poll_op( cls, ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), - data=To3DProTaskQueryRequest(JobId=response.JobId), + data=To3DProTaskQueryRequest(JobId=task_id), response_model=To3DProTaskResultResponse, status_extractor=lambda r: r.Status, ) - model_file = f"hunyuan_model_{response.JobId}.glb" - await download_url_to_bytesio( - get_glb_obj_from_response(result.ResultFile3Ds).Url, - os.path.join(get_output_directory(), model_file), + glb_result = get_file_from_response(result.ResultFile3Ds, "glb") + obj_result = get_file_from_response(result.ResultFile3Ds, "obj") + file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None + return IO.NodeOutput( + file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None ) - return IO.NodeOutput(model_file) class TencentImageToModelNode(IO.ComfyNode): @@ -184,7 +184,9 @@ class TencentImageToModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DOBJ.Output(display_name="OBJ"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -269,19 +271,20 @@ class TencentImageToModelNode(IO.ComfyNode): ) if response.Error: raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + task_id = response.JobId result = await poll_op( cls, ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), - data=To3DProTaskQueryRequest(JobId=response.JobId), + data=To3DProTaskQueryRequest(JobId=task_id), response_model=To3DProTaskResultResponse, status_extractor=lambda r: r.Status, ) - model_file = f"hunyuan_model_{response.JobId}.glb" - await download_url_to_bytesio( - get_glb_obj_from_response(result.ResultFile3Ds).Url, - os.path.join(get_output_directory(), model_file), + glb_result = get_file_from_response(result.ResultFile3Ds, "glb") + obj_result = get_file_from_response(result.ResultFile3Ds, "obj") + file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None + return IO.NodeOutput( + file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None ) - return IO.NodeOutput(model_file) class TencentHunyuan3DExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py index 740607983..65f6f0d2d 100644 --- a/comfy_api_nodes/nodes_meshy.py +++ b/comfy_api_nodes/nodes_meshy.py @@ -1,5 +1,3 @@ -import os - from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input @@ -20,13 +18,12 @@ from comfy_api_nodes.apis.meshy import ( ) from comfy_api_nodes.util import ( ApiEndpoint, - download_url_to_bytesio, + download_url_to_file_3d, poll_op, sync_op, upload_images_to_comfyapi, validate_string, ) -from folder_paths import get_output_directory class MeshyTextToModelNode(IO.ComfyNode): @@ -79,8 +76,10 @@ class MeshyTextToModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -122,16 +121,20 @@ class MeshyTextToModelNode(IO.ComfyNode): seed=seed, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"), response_model=MeshyModelResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) class MeshyRefineNode(IO.ComfyNode): @@ -167,8 +170,10 @@ class MeshyRefineNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -210,16 +215,20 @@ class MeshyRefineNode(IO.ComfyNode): ai_model=model, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"), response_model=MeshyModelResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) class MeshyImageToModelNode(IO.ComfyNode): @@ -303,8 +312,10 @@ class MeshyImageToModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -368,16 +379,20 @@ class MeshyImageToModelNode(IO.ComfyNode): seed=seed, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{task_id}"), response_model=MeshyModelResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) class MeshyMultiImageToModelNode(IO.ComfyNode): @@ -464,8 +479,10 @@ class MeshyMultiImageToModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -531,16 +548,20 @@ class MeshyMultiImageToModelNode(IO.ComfyNode): seed=seed, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{task_id}"), response_model=MeshyModelResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) class MeshyRigModelNode(IO.ComfyNode): @@ -571,8 +592,10 @@ class MeshyRigModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -606,18 +629,20 @@ class MeshyRigModelNode(IO.ComfyNode): texture_image_url=texture_image_url, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{task_id}"), response_model=MeshyRiggedResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio( - result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.result.rigged_character_glb_url, "glb", task_id=task_id), + await download_url_to_file_3d(result.result.rigged_character_fbx_url, "fbx", task_id=task_id), ) - return IO.NodeOutput(model_file, response.result) class MeshyAnimateModelNode(IO.ComfyNode): @@ -640,7 +665,9 @@ class MeshyAnimateModelNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -669,16 +696,19 @@ class MeshyAnimateModelNode(IO.ComfyNode): action_id=action_id, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{task_id}"), response_model=MeshyAnimationResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + await download_url_to_file_3d(result.result.animation_glb_url, "glb", task_id=task_id), + await download_url_to_file_3d(result.result.animation_fbx_url, "fbx", task_id=task_id), + ) class MeshyTextureNode(IO.ComfyNode): @@ -715,8 +745,10 @@ class MeshyTextureNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -760,16 +792,20 @@ class MeshyTextureNode(IO.ComfyNode): image_style_url=image_style_url, ), ) + task_id = response.result result = await poll_op( cls, - ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"), + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{task_id}"), response_model=MeshyModelResult, status_extractor=lambda r: r.status, progress_extractor=lambda r: r.progress, ) - model_file = f"meshy_model_{response.result}.glb" - await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) - return IO.NodeOutput(model_file, response.result) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) class MeshyExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 3ffdc8b90..f9cff121f 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -10,7 +10,6 @@ import folder_paths as comfy_paths import os import logging import math -from typing import Optional from io import BytesIO from typing_extensions import override from PIL import Image @@ -28,8 +27,9 @@ from comfy_api_nodes.util import ( poll_op, ApiEndpoint, download_url_to_bytesio, + download_url_to_file_3d, ) -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import ComfyExtension, IO, Types COMMON_PARAMETERS = [ @@ -177,7 +177,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: return "DONE" return "Generating" -def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]: +def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None: if not response.jobs: return None completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done) @@ -207,17 +207,25 @@ async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3D ) -async def download_files(url_list, task_uuid: str): +async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.File3D | None]: result_folder_name = f"Rodin3D_{task_uuid}" save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name) os.makedirs(save_path, exist_ok=True) model_file_path = None + file_3d = None + for i in url_list.list: file_path = os.path.join(save_path, i.name) - if file_path.endswith(".glb"): + if i.name.lower().endswith(".glb"): model_file_path = os.path.join(result_folder_name, i.name) - await download_url_to_bytesio(i.url, file_path) - return model_file_path + file_3d = await download_url_to_file_3d(i.url, "glb") + # Save to disk for backward compatibility + with open(file_path, "wb") as f: + f.write(file_3d.get_bytes()) + else: + await download_url_to_bytesio(i.url, file_path) + + return model_file_path, file_3d class Rodin3D_Regular(IO.ComfyNode): @@ -234,7 +242,10 @@ class Rodin3D_Regular(IO.ComfyNode): IO.Image.Input("Images"), *COMMON_PARAMETERS, ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, @@ -271,9 +282,9 @@ class Rodin3D_Regular(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Detail(IO.ComfyNode): @@ -290,7 +301,10 @@ class Rodin3D_Detail(IO.ComfyNode): IO.Image.Input("Images"), *COMMON_PARAMETERS, ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, @@ -327,9 +341,9 @@ class Rodin3D_Detail(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Smooth(IO.ComfyNode): @@ -346,7 +360,10 @@ class Rodin3D_Smooth(IO.ComfyNode): IO.Image.Input("Images"), *COMMON_PARAMETERS, ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, @@ -382,9 +399,9 @@ class Rodin3D_Smooth(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Sketch(IO.ComfyNode): @@ -408,7 +425,10 @@ class Rodin3D_Sketch(IO.ComfyNode): optional=True, ), ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, @@ -441,9 +461,9 @@ class Rodin3D_Sketch(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Gen2(IO.ComfyNode): @@ -475,7 +495,10 @@ class Rodin3D_Gen2(IO.ComfyNode): ), IO.Boolean.Input("TAPose", default=False), ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, @@ -511,9 +534,9 @@ class Rodin3D_Gen2(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3DExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 5abf27b4d..67c7f59fc 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -1,10 +1,6 @@ -import os -from typing import Optional - -import torch from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.tripo import ( TripoAnimateRetargetRequest, TripoAnimateRigRequest, @@ -26,12 +22,11 @@ from comfy_api_nodes.apis.tripo import ( ) from comfy_api_nodes.util import ( ApiEndpoint, - download_url_as_bytesio, + download_url_to_file_3d, poll_op, sync_op, upload_images_to_comfyapi, ) -from folder_paths import get_output_directory def get_model_url_from_response(response: TripoTaskResponse) -> str: @@ -45,7 +40,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str: async def poll_until_finished( node_cls: type[IO.ComfyNode], response: TripoTaskResponse, - average_duration: Optional[int] = None, + average_duration: int | None = None, ) -> IO.NodeOutput: """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response.""" if response.code != 0: @@ -69,12 +64,8 @@ async def poll_until_finished( ) if response_poll.data.status == TripoTaskStatus.SUCCESS: url = get_model_url_from_response(response_poll) - bytesio = await download_url_as_bytesio(url) - # Save the downloaded model file - model_file = f"tripo_model_{task_id}.glb" - with open(os.path.join(get_output_directory(), model_file), "wb") as f: - f.write(bytesio.getvalue()) - return IO.NodeOutput(model_file, task_id) + file_glb = await download_url_to_file_3d(url, "glb", task_id=task_id) + return IO.NodeOutput(f"{task_id}.glb", task_id, file_glb) raise RuntimeError(f"Failed to generate mesh: {response_poll}") @@ -107,8 +98,9 @@ class TripoTextToModelNode(IO.ComfyNode): IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -155,18 +147,18 @@ class TripoTextToModelNode(IO.ComfyNode): async def execute( cls, prompt: str, - negative_prompt: Optional[str] = None, + negative_prompt: str | None = None, model_version=None, - style: Optional[str] = None, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - image_seed: Optional[int] = None, - model_seed: Optional[int] = None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - geometry_quality: Optional[str] = None, - face_limit: Optional[int] = None, - quad: Optional[bool] = None, + style: str | None = None, + texture: bool | None = None, + pbr: bool | None = None, + image_seed: int | None = None, + model_seed: int | None = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + geometry_quality: str | None = None, + face_limit: int | None = None, + quad: bool | None = None, ) -> IO.NodeOutput: style_enum = None if style == "None" else style if not prompt: @@ -232,8 +224,9 @@ class TripoImageToModelNode(IO.ComfyNode): IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -279,19 +272,19 @@ class TripoImageToModelNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, - model_version: Optional[str] = None, - style: Optional[str] = None, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - model_seed: Optional[int] = None, + image: Input.Image, + model_version: str | None = None, + style: str | None = None, + texture: bool | None = None, + pbr: bool | None = None, + model_seed: int | None = None, orientation=None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - geometry_quality: Optional[str] = None, - texture_alignment: Optional[str] = None, - face_limit: Optional[int] = None, - quad: Optional[bool] = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + geometry_quality: str | None = None, + texture_alignment: str | None = None, + face_limit: int | None = None, + quad: bool | None = None, ) -> IO.NodeOutput: style_enum = None if style == "None" else style if image is None: @@ -368,8 +361,9 @@ class TripoMultiviewToModelNode(IO.ComfyNode): IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -411,21 +405,21 @@ class TripoMultiviewToModelNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, - image_left: Optional[torch.Tensor] = None, - image_back: Optional[torch.Tensor] = None, - image_right: Optional[torch.Tensor] = None, - model_version: Optional[str] = None, - orientation: Optional[str] = None, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - model_seed: Optional[int] = None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - geometry_quality: Optional[str] = None, - texture_alignment: Optional[str] = None, - face_limit: Optional[int] = None, - quad: Optional[bool] = None, + image: Input.Image, + image_left: Input.Image | None = None, + image_back: Input.Image | None = None, + image_right: Input.Image | None = None, + model_version: str | None = None, + orientation: str | None = None, + texture: bool | None = None, + pbr: bool | None = None, + model_seed: int | None = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + geometry_quality: str | None = None, + texture_alignment: str | None = None, + face_limit: int | None = None, + quad: bool | None = None, ) -> IO.NodeOutput: if image is None: raise RuntimeError("front image for multiview is required") @@ -487,8 +481,9 @@ class TripoTextureNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -512,11 +507,11 @@ class TripoTextureNode(IO.ComfyNode): async def execute( cls, model_task_id, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - texture_alignment: Optional[str] = None, + texture: bool | None = None, + pbr: bool | None = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + texture_alignment: str | None = None, ) -> IO.NodeOutput: response = await sync_op( cls, @@ -547,8 +542,9 @@ class TripoRefineNode(IO.ComfyNode): IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -583,8 +579,9 @@ class TripoRigNode(IO.ComfyNode): category="api node/3d/Tripo", inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -642,8 +639,9 @@ class TripoRetargetNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index c3c9ff4bf..18b020eef 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -28,6 +28,7 @@ from .conversions import ( from .download_helpers import ( download_url_as_bytesio, download_url_to_bytesio, + download_url_to_file_3d, download_url_to_image_tensor, download_url_to_video_output, ) @@ -69,6 +70,7 @@ __all__ = [ # Download helpers "download_url_as_bytesio", "download_url_to_bytesio", + "download_url_to_file_3d", "download_url_to_image_tensor", "download_url_to_video_output", # Conversions diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 4668d14a9..78bcf1fa1 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -11,7 +11,8 @@ import torch from aiohttp.client_exceptions import ClientError, ContentTypeError from comfy_api.latest import IO as COMFY_IO -from comfy_api.latest import InputImpl +from comfy_api.latest import InputImpl, Types +from folder_paths import get_output_directory from . import request_logger from ._helpers import ( @@ -261,3 +262,38 @@ def _generate_operation_id(method: str, url: str, attempt: int) -> str: except Exception: slug = "download" return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" + + +async def download_url_to_file_3d( + url: str, + file_format: str, + *, + task_id: str | None = None, + timeout: float | None = None, + max_retries: int = 5, + cls: type[COMFY_IO.ComfyNode] = None, +) -> Types.File3D: + """Downloads a 3D model file from a URL into memory as BytesIO. + + If task_id is provided, also writes the file to disk in the output directory + for backward compatibility with the old save-to-disk behavior. + """ + file_format = file_format.lstrip(".").lower() + data = BytesIO() + await download_url_to_bytesio( + url, + data, + timeout=timeout, + max_retries=max_retries, + cls=cls, + ) + + if task_id is not None: + # This is only for backward compatability with current behavior when every 3D node is output node + # All new API nodes should not use "task_id" and instead users should use "SaveGLB" node to save results + output_dir = Path(get_output_directory()) + output_path = output_dir / f"{task_id}.{file_format}" + output_path.write_bytes(data.getvalue()) + data.seek(0) + + return Types.File3D(source=data, file_format=file_format) diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 5bb5df48e..eda1639ab 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -622,14 +622,20 @@ class SaveGLB(IO.ComfyNode): category="3d", is_output_node=True, inputs=[ - IO.Mesh.Input("mesh"), + IO.MultiType.Input( + IO.Mesh.Input("mesh"), + types=[ + IO.File3DGLB, + ], + tooltip="Mesh or GLB file to save", + ), IO.String.Input("filename_prefix", default="mesh/ComfyUI"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo] ) @classmethod - def execute(cls, mesh, filename_prefix) -> IO.NodeOutput: + def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput: full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) results = [] @@ -641,15 +647,26 @@ class SaveGLB(IO.ComfyNode): for x in cls.hidden.extra_pnginfo: metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) - for i in range(mesh.vertices.shape[0]): + if isinstance(mesh, Types.File3D): + # Handle File3D input - save BytesIO data to output folder f = f"{filename}_{counter:05}_.glb" - save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata) + mesh.save_to(os.path.join(full_output_folder, f)) results.append({ "filename": f, "subfolder": subfolder, "type": "output" }) - counter += 1 + else: + # Handle Mesh input - save vertices and faces as GLB + for i in range(mesh.vertices.shape[0]): + f = f"{filename}_{counter:05}_.glb" + save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata) + results.append({ + "filename": f, + "subfolder": subfolder, + "type": "output" + }) + counter += 1 return IO.NodeOutput(ui={"3d": results}) diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 4b8d950ae..f29510488 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -1,9 +1,10 @@ import nodes import folder_paths import os +import uuid from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension, InputImpl, UI +from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types from pathlib import Path @@ -81,7 +82,19 @@ class Preview3D(IO.ComfyNode): is_experimental=True, is_output_node=True, inputs=[ - IO.String.Input("model_file", default="", multiline=False), + IO.MultiType.Input( + IO.String.Input("model_file", default="", multiline=False), + types=[ + IO.File3DGLB, + IO.File3DGLTF, + IO.File3DFBX, + IO.File3DOBJ, + IO.File3DSTL, + IO.File3DUSDZ, + IO.File3DAny, + ], + tooltip="3D model file or path string", + ), IO.Load3DCamera.Input("camera_info", optional=True), IO.Image.Input("bg_image", optional=True), ], @@ -89,10 +102,15 @@ class Preview3D(IO.ComfyNode): ) @classmethod - def execute(cls, model_file, **kwargs) -> IO.NodeOutput: + def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput: + if isinstance(model_file, Types.File3D): + filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}" + model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + else: + filename = model_file camera_info = kwargs.get("camera_info", None) bg_image = kwargs.get("bg_image", None) - return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image)) + return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image)) process = execute # TODO: remove From ab1050bec3ecb89d1bc06c13e67626805bd8e1ee Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 10:54:23 -0800 Subject: [PATCH 24/32] Support ace step 1.5 base model loras. (#12252) --- comfy/lora.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/comfy/lora.py b/comfy/lora.py index 7b31d055c..44030bcab 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -332,6 +332,12 @@ def model_lora_keys_unet(model, key_map={}): key_map["{}".format(key_lora)] = k key_map["transformer.{}".format(key_lora)] = k + if isinstance(model, comfy.model_base.ACEStep15): + for k in sdk: + if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model.decoder."):-len(".weight")] + key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras + return key_map From b8315e66cbcf675d683ed83892314cbd3a3d1bf7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 11:40:45 -0800 Subject: [PATCH 25/32] Fix tiled vae for ace step 1.5 (#12253) --- comfy/sd.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 722c0c154..aa93e101d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -554,6 +554,8 @@ class VAE: elif "decoder.layers.1.layers.0.beta" in sd: config = {} param_key = None + self.upscale_ratio = 2048 + self.downscale_ratio = 2048 if "decoder.layers.2.layers.1.weight_v" in sd: param_key = "decoder.layers.2.layers.1.weight_v" if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd: @@ -562,6 +564,8 @@ class VAE: if sd[param_key].shape[-1] == 12: config["strides"] = [2, 4, 4, 6, 10] self.audio_sample_rate = 48000 + self.upscale_ratio = 1920 + self.downscale_ratio = 1920 self.first_stage_model = AudioOobleckVAE(**config) self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) @@ -569,8 +573,6 @@ class VAE: self.latent_channels = 64 self.output_channels = 2 self.pad_channel_value = "replicate" - self.upscale_ratio = 2048 - self.downscale_ratio = 2048 self.latent_dim = 1 self.process_output = lambda audio: audio self.process_input = lambda audio: audio @@ -870,7 +872,7 @@ class VAE: / 3.0) return output - def decode_tiled_1d(self, samples, tile_x=128, overlap=32): + def decode_tiled_1d(self, samples, tile_x=256, overlap=32): if samples.ndim == 3: decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() else: From 3be017516643d9a8d8ed85904512ed4b240ef8c6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 3 Feb 2026 14:13:18 -0500 Subject: [PATCH 26/32] ComfyUI v0.12.1 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index bc6076b67..2e2c12ced 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.12.0" +__version__ = "0.12.1" diff --git a/pyproject.toml b/pyproject.toml index 28aa03067..c21ee03f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.12.0" +version = "0.12.1" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From fe2511468d6b3f7a1581a711ae480f3313f2b39d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 16:01:38 -0800 Subject: [PATCH 27/32] Support the 4B ace step 1.5 lm model. (#12257) Can be used as an alternative to the 1.7B --- comfy/sd.py | 7 +++- comfy/supported_models.py | 12 +++++- comfy/text_encoders/ace15.py | 42 ++++++++++++++++----- comfy/text_encoders/llama.py | 72 ++++++++++++++++++++++++++---------- 4 files changed, 101 insertions(+), 32 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index aa93e101d..bc63d6ced 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1444,7 +1444,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) elif clip_type == CLIPType.ACE: - clip_target.clip = comfy.text_encoders.ace15.te(**llama_detect(clip_data)) + te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])] + if TEModel.QWEN3_4B in te_models: + model_type = "qwen3_4b" + else: + model_type = "qwen3_2b" + clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6b7d831cb..77264ed28 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1625,8 +1625,16 @@ class ACEStep15(supported_models_base.BASE): def clip_target(self, state_dict={}): pref = self.text_encoder_key_prefix[0] - hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref)) - return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**hunyuan_detect)) + detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref)) + detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + if "dtype_llama" in detect_2b: + detect = detect_2b + detect["lm_model"] = "qwen3_2b" + elif "dtype_llama" in detect_4b: + detect = detect_4b + detect["lm_model"] = "qwen3_4b" + + return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 48c29fa9a..73d710671 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -162,14 +162,34 @@ class Qwen3_2B_ACE15(sd1_clip.SDClipModel): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) +class Qwen3_4B_ACE15(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + class ACE15TEModel(torch.nn.Module): - def __init__(self, device="cpu", dtype=None, dtype_llama=None, model_options={}): + def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}): super().__init__() if dtype_llama is None: dtype_llama = dtype + model = None + self.constant = 0.4375 + if lm_model == "qwen3_4b": + model = Qwen3_4B_ACE15 + self.constant = 0.5625 + elif lm_model == "qwen3_2b": + model = Qwen3_2B_ACE15 + + self.lm_model = lm_model self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options) - self.qwen3_2b = Qwen3_2B_ACE15(device=device, dtype=dtype_llama, model_options=model_options) + if model is not None: + setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options)) + self.dtypes = set([dtype, dtype_llama]) def encode_token_weights(self, token_weight_pairs): @@ -182,17 +202,21 @@ class ACE15TEModel(torch.nn.Module): lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics) lm_metadata = token_weight_pairs["lm_metadata"] - audio_codes = generate_audio_codes(self.qwen3_2b, token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"]) + audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"]) return base_out, None, {"conditioning_lyrics": lyrics_embeds[:, 0], "audio_codes": [audio_codes]} def set_clip_options(self, options): self.qwen3_06b.set_clip_options(options) - self.qwen3_2b.set_clip_options(options) + lm_model = getattr(self, self.lm_model, None) + if lm_model is not None: + lm_model.set_clip_options(options) def reset_clip_options(self): self.qwen3_06b.reset_clip_options() - self.qwen3_2b.reset_clip_options() + lm_model = getattr(self, self.lm_model, None) + if lm_model is not None: + lm_model.reset_clip_options() def load_sd(self, sd): if "model.layers.0.post_attention_layernorm.weight" in sd: @@ -200,11 +224,11 @@ class ACE15TEModel(torch.nn.Module): if shape[0] == 1024: return self.qwen3_06b.load_sd(sd) else: - return self.qwen3_2b.load_sd(sd) + return getattr(self, self.lm_model).load_sd(sd) def memory_estimation_function(self, token_weight_pairs, device=None): lm_metadata = token_weight_pairs["lm_metadata"] - constant = 0.4375 + constant = self.constant if comfy.model_management.should_use_bf16(device): constant *= 0.5 @@ -213,11 +237,11 @@ class ACE15TEModel(torch.nn.Module): num_tokens += lm_metadata['min_tokens'] return num_tokens * constant * 1024 * 1024 -def te(dtype_llama=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"): class ACE15TEModel_(ACE15TEModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: model_options = model_options.copy() model_options["llama_quantization_metadata"] = llama_quantization_metadata - super().__init__(device=device, dtype_llama=dtype_llama, dtype=dtype, model_options=model_options) + super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options) return ACE15TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index f4ac224c6..3afd094d1 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -150,6 +150,29 @@ class Qwen3_2B_ACE15_lm_Config: final_norm: bool = True lm_head: bool = False +@dataclass +class Qwen3_4B_ACE15_lm_Config: + vocab_size: int = 217204 + hidden_size: int = 2560 + intermediate_size: int = 9728 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + @dataclass class Qwen3_4BConfig: vocab_size: int = 151936 @@ -739,6 +762,21 @@ class BaseLlama: def forward(self, input_ids, *args, **kwargs): return self.model(input_ids, *args, **kwargs) +class BaseQwen3: + def logits(self, x): + input = x[:, -1:] + module = self.model.embed_tokens + + offload_stream = None + if module.comfy_cast_weights: + weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True) + else: + weight = self.model.embed_tokens.weight.to(x) + + x = torch.nn.functional.linear(input, weight, None) + + comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) + return x class Llama2(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): @@ -767,7 +805,7 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_06B(BaseLlama, torch.nn.Module): +class Qwen3_06B(BaseLlama, BaseQwen3, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_06BConfig(**config_dict) @@ -776,7 +814,7 @@ class Qwen3_06B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module): +class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_06B_ACE15_Config(**config_dict) @@ -785,7 +823,7 @@ class Qwen3_06B_ACE15(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module): +class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_2B_ACE15_lm_Config(**config_dict) @@ -794,22 +832,7 @@ class Qwen3_2B_ACE15_lm(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype - def logits(self, x): - input = x[:, -1:] - module = self.model.embed_tokens - - offload_stream = None - if module.comfy_cast_weights: - weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True) - else: - weight = self.model.embed_tokens.weight.to(x) - - x = torch.nn.functional.linear(input, weight, None) - - comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) - return x - -class Qwen3_4B(BaseLlama, torch.nn.Module): +class Qwen3_4B(BaseLlama, BaseQwen3, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_4BConfig(**config_dict) @@ -818,7 +841,16 @@ class Qwen3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_8B(BaseLlama, torch.nn.Module): +class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_4B_ACE15_lm_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Qwen3_8B(BaseLlama, BaseQwen3, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_8BConfig(**config_dict) From 855849c6588180fec88186127aae1a3299387fa6 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Feb 2026 18:39:19 -0800 Subject: [PATCH 28/32] mm: Remove Aimdo exemption for empty_cache (#12260) Its more important to get the torch caching allocator GC up and running than supporting the pyt2.7 bug. Switch it on. Defeature dynamic_vram + pyt2.7. --- comfy/model_management.py | 8 +++----- main.py | 7 +++++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 72348258b..b6291f340 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1724,11 +1724,9 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): - if comfy.memory_management.aimdo_allocator is None: - #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, get_torch_device()) diff --git a/main.py b/main.py index b8c951375..92d705b4d 100644 --- a/main.py +++ b/main.py @@ -192,7 +192,10 @@ import comfy_aimdo.control import comfy_aimdo.torch if enables_dynamic_vram(): - if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + if comfy.model_management.torch_version_numeric < (2, 8): + logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + comfy.memory_management.aimdo_allocator = None + elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): if args.verbose == 'DEBUG': comfy_aimdo.control.set_log_debug() elif args.verbose == 'CRITICAL': @@ -208,7 +211,7 @@ if enables_dynamic_vram(): comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() logging.info("DynamicVRAM support detected and enabled") else: - logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") comfy.memory_management.aimdo_allocator = None From a31681564d9bd49530df593d87048232acecd842 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 3 Feb 2026 21:03:21 -0800 Subject: [PATCH 29/32] Fix crash with ace step 1.5 (#12264) --- comfy/text_encoders/ace15.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 73d710671..fce2b67ce 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -19,6 +19,7 @@ def sample_manual_loop_no_classes( min_tokens: int = 1, max_new_tokens: int = 2048, audio_start_id: int = 151669, # The cutoff ID for audio codes + audio_end_id: int = 215669, eos_token_id: int = 151645, ): device = model.execution_device @@ -60,6 +61,7 @@ def sample_manual_loop_no_classes( remove_logit_value = torch.finfo(cfg_logits.dtype).min # Only generate audio tokens cfg_logits[:, :audio_start_id] = remove_logit_value + cfg_logits[:, audio_end_id:] = remove_logit_value if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step: cfg_logits[:, eos_token_id] = eos_score From 5087f1d497c5b615fbb5d1ff03fcc1df308bd025 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 4 Feb 2026 00:08:59 -0500 Subject: [PATCH 30/32] ComfyUI v0.12.2 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 2e2c12ced..5d296cd1b 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.12.1" +__version__ = "0.12.2" diff --git a/pyproject.toml b/pyproject.toml index c21ee03f1..1ddcc3596 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.12.1" +version = "0.12.2" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From d30c609f5aa1d0cc75852815308adbb6ae21c644 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:48:47 -0800 Subject: [PATCH 31/32] utils: safetensors: dont slice data on torch level (#12266) Torch has alignment enforcement when viewing with data type changes but only relative to itself. Do all tensor constructions straight off the memory-view individually so pytorch doesnt see an alignment problem. The is needed for handling misaligned safetensors weights, which are reasonably common in third party models. This limits usage of this safetensors loader to GPU compute only as CPUs kernnel are very likely to bus error. But it works for dynamic_vram, where we really dont want to take a deep copy and we always use GPU copy_ which disentangles the misalignment. --- comfy/utils.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index c1b536833..1337e2205 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -82,14 +82,12 @@ _TYPES = { def load_safetensors(ckpt): f = open(ckpt, "rb") mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + mv = memoryview(mapping) header_size = struct.unpack(" Date: Tue, 3 Feb 2026 23:08:45 -0800 Subject: [PATCH 32/32] mp: Fix checkpoint saving (#12268) Fix regression in the recent model saving refactor. Pass the non unet pieces down the layers so that checkpoints are complete. --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index cdf289395..d888dbcfb 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1400,7 +1400,7 @@ class ModelPatcher: continue key = "diffusion_model." + k unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) - return self.model.state_dict_for_saving(unet_state_dict) + return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) def __del__(self): self.unpin_all_weights()