From aa6f7a83bb313f121ee032a8b21aa72ea32cfee1 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 31 Jan 2026 17:05:11 -0800 Subject: [PATCH 1/6] Send is_input_list on v1 and v3 schema to frontend (#12188) --- comfy_api/latest/_io.py | 2 ++ server.py | 1 + 2 files changed, 3 insertions(+) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 78f77d4b2..eeea9781a 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1248,6 +1248,7 @@ class Hidden(str, Enum): class NodeInfoV1: input: dict=None input_order: dict[str, list[str]]=None + is_input_list: bool=None output: list[str]=None output_is_list: list[bool]=None output_name: list[str]=None @@ -1474,6 +1475,7 @@ class Schema: info = NodeInfoV1( input=input, input_order={key: list(value.keys()) for (key, value) in input.items()}, + is_input_list=self.is_input_list, output=output, output_is_list=output_is_list, output_name=output_name, diff --git a/server.py b/server.py index 2aee5cc06..2300393b2 100644 --- a/server.py +++ b/server.py @@ -656,6 +656,7 @@ class PromptServer(): info = {} info['input'] = obj_class.INPUT_TYPES() info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} + info['is_input_list'] = getattr(obj_class, "INPUT_IS_LIST", False) info['output'] = obj_class.RETURN_TYPES info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] From 873de5f37a1f4c39c0203b2989e6736b0cb12b98 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 31 Jan 2026 18:11:11 -0800 Subject: [PATCH 2/6] KV cache implementation for using llama models for text generation. (#12195) --- comfy/text_encoders/llama.py | 91 +++++++++++++++++++++++++++++------- 1 file changed, 74 insertions(+), 17 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 3080a3e09..68ac1e804 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from dataclasses import dataclass -from typing import Optional, Any +from typing import Optional, Any, Tuple import math from comfy.ldm.modules.attention import optimized_attention_for_device @@ -32,6 +32,7 @@ class Llama2Config: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Mistral3Small24BConfig: @@ -54,6 +55,7 @@ class Mistral3Small24BConfig: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen25_3BConfig: @@ -76,6 +78,7 @@ class Qwen25_3BConfig: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen3_06BConfig: @@ -98,6 +101,7 @@ class Qwen3_06BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen3_4BConfig: @@ -120,6 +124,7 @@ class Qwen3_4BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen3_8BConfig: @@ -142,6 +147,7 @@ class Qwen3_8BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Ovis25_2BConfig: @@ -164,6 +170,7 @@ class Ovis25_2BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen25_7BVLI_Config: @@ -186,6 +193,7 @@ class Qwen25_7BVLI_Config: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Gemma2_2B_Config: @@ -209,6 +217,7 @@ class Gemma2_2B_Config: sliding_attention = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Gemma3_4B_Config: @@ -232,6 +241,7 @@ class Gemma3_4B_Config: sliding_attention = [1024, 1024, 1024, 1024, 1024, False] rope_scale = [8.0, 1.0] final_norm: bool = True + lm_head: bool = False @dataclass class Gemma3_12B_Config: @@ -255,6 +265,7 @@ class Gemma3_12B_Config: sliding_attention = [1024, 1024, 1024, 1024, 1024, False] rope_scale = [8.0, 1.0] final_norm: bool = True + lm_head: bool = False vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} mm_tokens_per_image = 256 @@ -356,6 +367,7 @@ class Attention(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): batch_size, seq_length, _ = hidden_states.shape xq = self.q_proj(hidden_states) @@ -373,11 +385,30 @@ class Attention(nn.Module): xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis) + present_key_value = None + if past_key_value is not None: + index = 0 + num_tokens = xk.shape[2] + if len(past_key_value) > 0: + past_key, past_value, index = past_key_value + if past_key.shape[2] >= (index + num_tokens): + past_key[:, :, index:index + xk.shape[2]] = xk + past_value[:, :, index:index + xv.shape[2]] = xv + xk = past_key[:, :, :index + xk.shape[2]] + xv = past_value[:, :, :index + xv.shape[2]] + present_key_value = (past_key, past_value, index + num_tokens) + else: + xk = torch.cat((past_key[:, :, :index], xk), dim=2) + xv = torch.cat((past_value[:, :, :index], xv), dim=2) + present_key_value = (xk, xv, index + num_tokens) + else: + present_key_value = (xk, xv, index + num_tokens) + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True) - return self.o_proj(output) + return self.o_proj(output), present_key_value class MLP(nn.Module): def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): @@ -408,15 +439,17 @@ class TransformerBlock(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): # Self Attention residual = x x = self.input_layernorm(x) - x = self.self_attn( + x, present_key_value = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_key_value, ) x = residual + x @@ -426,7 +459,7 @@ class TransformerBlock(nn.Module): x = self.mlp(x) x = residual + x - return x + return x, present_key_value class TransformerBlockGemma2(nn.Module): def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None): @@ -451,6 +484,7 @@ class TransformerBlockGemma2(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): if self.transformer_type == 'gemma3': if self.sliding_attention: @@ -468,11 +502,12 @@ class TransformerBlockGemma2(nn.Module): # Self Attention residual = x x = self.input_layernorm(x) - x = self.self_attn( + x, present_key_value = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_key_value, ) x = self.post_attention_layernorm(x) @@ -485,7 +520,7 @@ class TransformerBlockGemma2(nn.Module): x = self.post_feedforward_layernorm(x) x = residual + x - return x + return x, present_key_value class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): @@ -516,9 +551,10 @@ class Llama2_(nn.Module): else: self.norm = None - # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) + if config.lm_head: + self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): if embeds is not None: x = embeds else: @@ -527,8 +563,13 @@ class Llama2_(nn.Module): if self.normalize_in: x *= self.config.hidden_size ** 0.5 + seq_len = x.shape[1] + past_len = 0 + if past_key_values is not None and len(past_key_values) > 0: + past_len = past_key_values[0][2] + if position_ids is None: - position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0) + position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0) freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, @@ -539,14 +580,16 @@ class Llama2_(nn.Module): mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) - if mask is not None: - mask += causal_mask - else: - mask = causal_mask + if seq_len > 1: + causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if mask is not None: + mask += causal_mask + else: + mask = causal_mask + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) intermediate = None @@ -562,16 +605,27 @@ class Llama2_(nn.Module): elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output + next_key_values = [] for i, layer in enumerate(self.layers): if all_intermediate is not None: if only_layers is None or (i in only_layers): all_intermediate.append(x.unsqueeze(1).clone()) - x = layer( + + past_kv = None + if past_key_values is not None: + past_kv = past_key_values[i] if len(past_key_values) > 0 else [] + + x, current_kv = layer( x=x, attention_mask=mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_kv, ) + + if current_kv is not None: + next_key_values.append(current_kv) + if i == intermediate_output: intermediate = x.clone() @@ -588,7 +642,10 @@ class Llama2_(nn.Module): if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: intermediate = self.norm(intermediate) - return x, intermediate + if len(next_key_values) > 0: + return x, intermediate, next_key_values + else: + return x, intermediate class Gemma3MultiModalProjector(torch.nn.Module): From f8acd9c402f41efc2c748503e4c5d9e0027965ab Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sat, 31 Jan 2026 22:01:11 -0800 Subject: [PATCH 3/6] Reduce RAM usage, fix VRAM OOMs, and fix Windows shared memory spilling with adaptive model loading (#11845) --- comfy/audio_encoders/audio_encoders.py | 4 +- comfy/cli_args.py | 4 + comfy/clip_vision.py | 4 +- comfy/controlnet.py | 2 +- comfy/k_diffusion/sampling.py | 33 ++- comfy/ldm/hunyuan_video/upsampler.py | 4 +- comfy/memory_management.py | 81 +++++++ comfy/model_base.py | 15 +- comfy/model_management.py | 170 ++++++++++++-- comfy/model_patcher.py | 313 +++++++++++++++++++++++-- comfy/ops.py | 221 +++++++++++++++-- comfy/pinned_memory.py | 30 +++ comfy/samplers.py | 3 +- comfy/sd.py | 59 +++-- comfy/sd1_clip.py | 2 +- comfy/text_encoders/lt.py | 2 +- comfy/utils.py | 80 ++++++- comfy/windows.py | 52 ++++ comfy_extras/nodes_model_patch.py | 6 +- cuda_malloc.py | 12 +- execution.py | 16 +- main.py | 30 ++- requirements.txt | 1 + 23 files changed, 1030 insertions(+), 114 deletions(-) create mode 100644 comfy/memory_management.py create mode 100644 comfy/pinned_memory.py create mode 100644 comfy/windows.py diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 46ef21c95..16998af94 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -25,11 +25,11 @@ class AudioEncoderModel(): elif model_type == "whisper3": self.model = WhisperLargeV3(**model_config) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 1716c3de7..63daca861 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" + DynamicVRAM = "dynamic_vram" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) @@ -257,3 +258,6 @@ elif args.fast == []: # '--fast' is provided with a list of performance features, use that list else: args.fast = set(args.fast) + +def enables_dynamic_vram(): + return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index b28bf636c..1691fca81 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -47,10 +47,10 @@ class ClipVisionModel(): self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b5e30f52..9e1e704e0 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -203,7 +203,7 @@ class ControlNet(ControlBase): self.control_model = control_model self.load_device = load_device if control_model is not None: - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.compression_ratio = compression_ratio self.global_average_pooling = global_average_pooling diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 0949dee44..c0c51d51a 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,11 +1,12 @@ import math +import time from functools import partial from scipy import integrate import torch from torch import nn import torchsde -from tqdm.auto import trange, tqdm +from tqdm.auto import trange as trange_, tqdm from . import utils from . import deis @@ -13,6 +14,36 @@ from . import sa_solver import comfy.model_patcher import comfy.model_sampling +import comfy.memory_management + + +def trange(*args, **kwargs): + if comfy.memory_management.aimdo_allocator is None: + return trange_(*args, **kwargs) + + pbar = trange_(*args, **kwargs, smoothing=1.0) + pbar._i = 0 + pbar.set_postfix_str(" Model Initializing ... ") + + _update = pbar.update + + def warmup_update(n=1): + pbar._i += 1 + if pbar._i == 1: + pbar.i1_time = time.time() + pbar.set_postfix_str(" Model Initialization complete! ") + elif pbar._i == 2: + #bring forward the effective start time based the the diff between first and second iteration + #to attempt to remove load overhead from the final step rate estimate. + pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time) + pbar.set_postfix_str("") + + _update(n) + + pbar.update = warmup_update + return pbar + + def append_zero(x): return torch.cat([x, x.new_zeros([1])]) diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 51b6d1da8..1f68144e2 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -109,10 +109,10 @@ class HunyuanVideo15SRModel(): self.model_class = UPSAMPLERS.get(model_type) self.model = self.model_class(**config).eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=True) + return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/memory_management.py b/comfy/memory_management.py new file mode 100644 index 000000000..858bd4cc7 --- /dev/null +++ b/comfy/memory_management.py @@ -0,0 +1,81 @@ +import math +import torch +from typing import NamedTuple + +from comfy.quant_ops import QuantizedTensor + +class TensorGeometry(NamedTuple): + shape: any + dtype: torch.dtype + + def element_size(self): + info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype) + return info.bits // 8 + + def numel(self): + return math.prod(self.shape) + +def tensors_to_geometries(tensors, dtype=None): + geometries = [] + for t in tensors: + if t is None or isinstance(t, QuantizedTensor): + geometries.append(t) + continue + tdtype = t.dtype + if hasattr(t, "_model_dtype"): + tdtype = t._model_dtype + if dtype is not None: + tdtype = dtype + geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype)) + return geometries + +def vram_aligned_size(tensor): + if isinstance(tensor, list): + return sum([vram_aligned_size(t) for t in tensor]) + + if isinstance(tensor, QuantizedTensor): + inner_tensors, _ = tensor.__tensor_flatten__() + return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ]) + + if tensor is None: + return 0 + + size = tensor.numel() * tensor.element_size() + aligment_req = 1024 + return (size + aligment_req - 1) // aligment_req * aligment_req + +def interpret_gathered_like(tensors, gathered): + offset = 0 + dest_views = [] + + if gathered.dim() != 1 or gathered.element_size() != 1: + raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})") + + for tensor in tensors: + + if tensor is None: + dest_views.append(None) + continue + + if isinstance(tensor, QuantizedTensor): + inner_tensors, qt_ctx = tensor.__tensor_flatten__() + templates = { attr: getattr(tensor, attr) for attr in inner_tensors } + else: + templates = { "data": tensor } + + actuals = {} + for attr, template in templates.items(): + size = template.numel() * template.element_size() + if offset + size > gathered.numel(): + raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. ") + actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape) + offset += vram_aligned_size(template) + + if isinstance(tensor, QuantizedTensor): + dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0)) + else: + dest_views.append(actuals["data"]) + + return dest_views + +aimdo_allocator = None diff --git a/comfy/model_base.py b/comfy/model_base.py index 66e52864d..85acdb66a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -149,6 +149,8 @@ class BaseModel(torch.nn.Module): self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) + comfy.model_management.archive_model_dtypes(self.diffusion_model) + self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 @@ -299,7 +301,7 @@ class BaseModel(torch.nn.Module): return out - def load_model_weights(self, sd, unet_prefix=""): + def load_model_weights(self, sd, unet_prefix="", assign=False): to_load = {} keys = list(sd.keys()) for k in keys: @@ -307,7 +309,7 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign) if len(m) > 0: logging.warning("unet missing: {}".format(m)) @@ -322,7 +324,7 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): extra_sds = [] if clip_state_dict is not None: extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) @@ -330,10 +332,7 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) if clip_vision_state_dict is not None: extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) - - unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) @@ -776,8 +775,8 @@ class StableAudio1(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} for k in d: s = d[k] diff --git a/comfy/model_management.py b/comfy/model_management.py index 9d39be7b2..758e718e8 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,13 @@ import platform import weakref import gc import os +from contextlib import nullcontext +import comfy.memory_management +import comfy.utils +import comfy.quant_ops + +import comfy_aimdo.torch +import comfy_aimdo.model_vbar class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -578,9 +585,15 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: + import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 + def get_free_ram(): + return comfy.windows.get_free_ram() +else: + def get_free_ram(): + return psutil.virtual_memory().available if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -592,7 +605,7 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() -def free_memory(memory_required, device, keep_loaded=[]): +def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0): cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -607,15 +620,23 @@ def free_memory(memory_required, device, keep_loaded=[]): for x in sorted(can_unload): i = x[-1] - memory_to_free = None + memory_to_free = 1e32 + ram_to_free = 1e32 if not DISABLE_SMART_MEMORY: - free_mem = get_free_memory(device) - if free_mem > memory_required: - break - memory_to_free = memory_required - free_mem - logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free): + memory_to_free = memory_required - get_free_memory(device) + ram_to_free = ram_required - get_free_ram() + + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. + memory_required -= current_loaded_models[i].model.loaded_size() + memory_to_free = 0 + if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): + logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) + if ram_to_free > 0: + logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") + current_loaded_models[i].model.partially_unload_ram(ram_to_free) for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -650,7 +671,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] + free_for_dynamic=True for x in models: + if not x.is_dynamic(): + free_for_dynamic = False loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) @@ -676,19 +700,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() + total_memory_required = {} + total_ram_required = {} for loaded_model in models_to_load: total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + #x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we + #want to do. + #FIXME: This should subtract off the to_load current pin consumption. + total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2 for device in total_memory_required: if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device) + free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device]) for device in total_memory_required: if device != torch.device("cpu"): free_mem = get_free_memory(device) if free_mem < minimum_memory_required: - models_l = free_memory(minimum_memory_required, device) + models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic) logging.info("{} models unloaded.".format(len(models_l))) for loaded_model in models_to_load: @@ -732,6 +762,9 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): do_gc = False + + reset_cast_buffers() + for i in range(len(current_loaded_models)): cur = current_loaded_models[i] if cur.is_dead(): @@ -749,6 +782,11 @@ def cleanup_models_gc(): logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) +def archive_model_dtypes(model): + for name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + def cleanup_models(): to_delete = [] @@ -792,7 +830,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev: + if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None: return torch_dev else: return cpu_dev @@ -1051,6 +1089,53 @@ def current_stream(device): return None stream_counters = {} + +STREAM_CAST_BUFFERS = {} +LARGEST_CASTED_WEIGHT = (None, 0) + +def get_cast_buffer(offload_stream, device, size, ref): + global LARGEST_CASTED_WEIGHT + + if offload_stream is not None: + wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) + else: + wf_context = nullcontext() + + cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None or cast_buffer.numel() < size: + if ref is LARGEST_CASTED_WEIGHT[0]: + #If there is one giant weight we do not want both streams to + #allocate a buffer for it. It's up to the caster to get the other + #offload stream in this corner case + return None + if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): + #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now + torch.cuda.synchronize() + del STREAM_CAST_BUFFERS[offload_stream] + del cast_buffer + #FIXME: This doesn't work in Aimdo because mempool cant clear cache + torch.cuda.empty_cache() + with wf_context: + cast_buffer = torch.empty((size), dtype=torch.int8, device=device) + STREAM_CAST_BUFFERS[offload_stream] = cast_buffer + + if size > LARGEST_CASTED_WEIGHT[1]: + LARGEST_CASTED_WEIGHT = (ref, size) + + return cast_buffer + +def reset_cast_buffers(): + global LARGEST_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) + for offload_stream in STREAM_CAST_BUFFERS: + offload_stream.synchronize() + STREAM_CAST_BUFFERS.clear() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.empty_cache() + def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) if NUM_STREAMS == 0: @@ -1093,7 +1178,53 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): + +def cast_to_gathered(tensors, r, non_blocking=False, stream=None): + wf_context = nullcontext() + if stream is not None: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + + dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + with wf_context: + for tensor in tensors: + dest_view = dest_views.pop(0) + if tensor is None: + continue + dest_view.copy_(tensor, non_blocking=non_blocking) + + +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): + if hasattr(weight, "_v"): + #Unexpected usage patterns. There is no reason these don't work but they + #have no testing and no callers do this. + assert r is None + assert stream is None + + r = torch.empty_like(weight, dtype=weight._model_dtype, device=device) + + signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) + if signature is not None: + raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) + v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0] + + if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + #always take a deep copy even if _v is good, as we have no reasonable point to unpin + #a non comfy weight + r.copy_(v_tensor) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + return r + + r.copy_(weight, non_blocking=non_blocking) + + if signature is not None: + weight._v_signature = signature + v_tensor.copy_(r) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + + return r + if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: @@ -1112,10 +1243,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if hasattr(wf_context, "as_context"): wf_context = wf_context.as_context(stream) with wf_context: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r @@ -1135,7 +1268,7 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) -PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) def discard_cuda_async_error(): try: @@ -1557,8 +1690,11 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, get_torch_device()) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f6b80a40f..57b53d8c5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -38,19 +38,7 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP - -def string_to_seed(data): - crc = 0xFFFFFFFF - for byte in data: - if isinstance(byte, str): - byte = ord(byte) - crc ^= byte - for _ in range(8): - if crc & 1: - crc = (crc >> 1) ^ 0xEDB88320 - else: - crc >>= 1 - return crc ^ 0xFFFFFFFF +import comfy_aimdo.model_vbar def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -212,6 +200,27 @@ class MemoryCounter: def decrement(self, used: int): self.value -= used +CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0) + +class LazyCastingParam(torch.nn.Parameter): + def __new__(cls, model, key, tensor): + return super().__new__(cls, tensor) + + def __init__(self, model, key, tensor): + self.model = model + self.key = key + + @property + def device(self): + return CustomTorchDevice + + #safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is + #then just a short lived thing in the safetensors serialization logic inside its big for loop over + #all weights getting garbage collected per-weight + def to(self, *args, **kwargs): + return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu") + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -269,6 +278,9 @@ class ModelPatcher: if not hasattr(self.model, 'model_offload_buffer_memory'): self.model.model_offload_buffer_memory = 0 + def is_dynamic(self): + return False + def model_size(self): if self.size > 0: return self.size @@ -284,6 +296,9 @@ class ModelPatcher: def lowvram_patch_counter(self): return self.model.lowvram_patch_counter + def get_free_memory(self, device): + return comfy.model_management.get_free_memory(device) + def clone(self): n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} @@ -611,14 +626,14 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): - if key not in self.patches: - return - + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): weight, set_func, convert_func = get_key_weight(self.model, key) + if key not in self.patches: + return weight + inplace_update = self.weight_inplace_update or inplace_update - if key not in self.backup: + if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) @@ -631,13 +646,15 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) - if inplace_update: + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) + if return_weight: + return out_weight + elif inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -654,7 +671,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self): + def _load_list(self, prio_comfy_cast_weights=False): loading = [] for n, m in self.model.named_modules(): params = [] @@ -681,7 +698,8 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - loading.append((module_offload_mem, module_mem, n, m, params)) + prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () + loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -984,6 +1002,9 @@ class ModelPatcher: return self.model.model_loaded_weight_memory - current_used + def partially_unload_ram(self, ram_to_unload): + pass + def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) @@ -1317,10 +1338,10 @@ class ModelPatcher: key, original_weights=original_weights) del original_weights[key] if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) comfy.utils.copy_to_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=True, seed=string_to_seed(key)) + set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key)) if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # TODO: disable caching if not enough system RAM to do so target_device = self.offload_device @@ -1355,7 +1376,249 @@ class ModelPatcher: self.unpatch_hooks() self.clear_cached_hook_weights() + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model.diffusion_model.state_dict() + for k, v in unet_state_dict.items(): + op_keys = k.rsplit('.', 1) + if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]: + continue + try: + op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + except: + continue + if not op or not hasattr(op, "comfy_cast_weights") or \ + (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): + continue + key = "diffusion_model." + k + unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) + return self.model.state_dict_for_saving(unet_state_dict) + def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) +class ModelPatcherDynamic(ModelPatcher): + + def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False): + if load_device is not None and comfy.model_management.is_device_cpu(load_device): + #reroute to default MP for CPUs + return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update) + return super().__new__(cls) + + def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): + super().__init__(model, load_device, offload_device, size, weight_inplace_update) + #this is now way more dynamic and we dont support the same base model for both Dynamic + #and non-dynamic patchers. + if hasattr(self.model, "model_loaded_weight_memory"): + del self.model.model_loaded_weight_memory + if not hasattr(self.model, "dynamic_vbars"): + self.model.dynamic_vbars = {} + assert load_device is not None + + def is_dynamic(self): + return True + + def _vbar_get(self, create=False): + if self.load_device == torch.device("cpu"): + return None + vbar = self.model.dynamic_vbars.get(self.load_device, None) + if create and vbar is None: + # x10. We dont know what model defined type casts we have in the vbar, but virtual address + # space is pretty free. This will cover someone casting an entire model from FP4 to FP32 + # with some left over. + vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index) + self.model.dynamic_vbars[self.load_device] = vbar + return vbar + + def loaded_size(self): + vbar = self._vbar_get() + if vbar is None: + return 0 + return vbar.loaded_size() + + def get_free_memory(self, device): + #NOTE: on high condition / batch counts, estimate should have already vacated + #all non-dynamic models so this is safe even if its not 100% true that this + #would all be avaiable for inference use. + return comfy.model_management.get_total_memory(device) - self.model_size() + + #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. + + def pin_weight_to_device(self, key): + raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading") + + def unpin_weight(self, key): + raise RuntimeError("unpin_weight invalid for dymamic weight loading") + + def unpin_all_weights(self): + self.partially_unload_ram(1e32) + + def memory_required(self, input_shape): + #Pad this significantly. We are trying to get away from precise estimates. This + #estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you + #use all ModelPatcherDynamic this is ignored and its all done dynamically. + return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3) + + + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False): + + #Force patching doesn't make sense in Dynamic loading, as you dont know what does and + #doesn't need to be forced at this stage. The only thing you could do would be patch + #it all on CPU which consumes huge RAM. + assert not force_patch_weights + + #Full load doesn't make sense as we dont actually have any loader capability here and + #now. + assert not full_load + + assert device_to == self.load_device + + num_patches = 0 + allocated_size = 0 + + with self.use_ejected(): + self.unpatch_hooks() + + vbar = self._vbar_get(create=True) + if vbar is not None: + vbar.prioritize() + + #We have way more tools for acceleration on comfy weight offloading, so always + #prioritize the non-comfy weights (note the order reverse). + loading = self._load_list(prio_comfy_cast_weights=True) + loading.sort(reverse=True) + + for x in loading: + _, _, _, n, m, params = x + + def set_dirty(item, dirty): + if dirty or not hasattr(item, "_v_signature"): + item._v_signature = None + + def setup_param(self, m, n, param_key): + nonlocal num_patches + key = "{}.{}".format(n, param_key) + + weight_function = [] + + weight, _, _ = get_key_weight(self.model, key) + if weight is None: + return 0 + if key in self.patches: + setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + num_patches += 1 + else: + setattr(m, param_key + "_lowvram_function", None) + + if key in self.weight_wrapper_patches: + weight_function.extend(self.weight_wrapper_patches[key]) + setattr(m, param_key + "_function", weight_function) + geometry = weight + if not isinstance(weight, QuantizedTensor): + model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype) + weight._model_dtype = model_dtype + geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) + return comfy.memory_management.vram_aligned_size(geometry) + + if hasattr(m, "comfy_cast_weights"): + m.comfy_cast_weights = True + m.pin_failed = False + m.seed_key = n + set_dirty(m, dirty) + + v_weight_size = 0 + v_weight_size += setup_param(self, m, n, "weight") + v_weight_size += setup_param(self, m, n, "bias") + + if vbar is not None and not hasattr(m, "_v"): + m._v = vbar.alloc(v_weight_size) + allocated_size += v_weight_size + + else: + for param in params: + key = "{}.{}".format(n, param) + weight, _, _ = get_key_weight(self.model, key) + weight.seed_key = key + set_dirty(weight, dirty) + geometry = weight + model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype) + geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) + weight_size = geometry.numel() * geometry.element_size() + if vbar is not None and not hasattr(weight, "_v"): + weight._v = vbar.alloc(weight_size) + weight._model_dtype = model_dtype + allocated_size += weight_size + + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + + self.model.device = device_to + self.model.current_weight_patches_uuid = self.patches_uuid + + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): + #These are all super dangerous. Who knows what the custom nodes actually do here... + callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) + + self.apply_hooks(self.forced_hooks, force_apply=True) + + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): + assert not force_patch_weights #See above + assert self.load_device != torch.device("cpu") + + vbar = self._vbar_get() + return 0 if vbar is None else vbar.free_memory(memory_to_free) + + def partially_unload_ram(self, ram_to_unload): + loading = self._load_list(prio_comfy_cast_weights=True) + for x in loading: + _, _, _, _, m, _ = x + ram_to_unload -= comfy.pinned_memory.unpin_memory(m) + if ram_to_unload <= 0: + return + + def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): + #This isn't used by the core at all and can only be to load a model out of + #the control of proper model_managment. If you are a custom node author reading + #this, the correct pattern is to call load_models_gpu() to get a proper + #managed load of your model. + assert not load_weights + return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights) + + def unpatch_model(self, device_to=None, unpatch_weights=True): + super().unpatch_model(device_to=None, unpatch_weights=False) + + if unpatch_weights: + self.partially_unload_ram(1e32) + self.partially_unload(None) + + def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): + assert not force_patch_weights #See above + with self.use_ejected(skip_and_inject_on_exit_only=True): + dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid) + + self.unpatch_model(self.offload_device, unpatch_weights=False) + self.patch_model(load_weights=False) + + try: + self.load(device_to, dirty=dirty) + except Exception as e: + self.detach() + raise e + #ModelPatcher::partially_load returns a number on what got loaded but + #nothing in core uses this and we have no data in the Dynamic world. Hit + #the custom node devs with a None rather than a 0 that would mislead any + #logic they might have. + return None + + def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): + assert False #Should be unreachable - we dont ever cache in the new implementation + + def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): + if key not in combined_patches: + return + + raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments form ComfyUI startup") + + def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: + pass + +CoreModelPatcher = ModelPatcher diff --git a/comfy/ops.py b/comfy/ops.py index e406ba7ed..c3a1825ce 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,10 +19,16 @@ import torch import logging import comfy.model_management -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import comfy.float import comfy.rmsnorm import json +import comfy.memory_management +import comfy.pinned_memory +import comfy.utils + +import comfy_aimdo.model_vbar +import comfy_aimdo.torch def run_every_op(): if torch.compiler.is_compiling(): @@ -72,7 +78,115 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): + offload_stream = None + xfer_dest = None + cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) + + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + if signature is not None: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + + if not resident: + cast_dest = None + + xfer_source = [ s.weight, s.bias ] + + pin = comfy.pinned_memory.get_pin(s) + if pin is not None: + xfer_source = [ pin ] + else: + for data, geometry in zip([ s.weight, s.bias ], cast_geometry): + if data is None: + continue + if data.dtype != geometry.dtype: + cast_dest = xfer_dest + if cast_dest is None: + cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) + xfer_dest = None + break + + dest_size = comfy.memory_management.vram_aligned_size(xfer_source) + offload_stream = comfy.model_management.get_offload_stream(device) + if xfer_dest is None and offload_stream is not None: + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + offload_stream = comfy.model_management.get_offload_stream(device) + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) + offload_stream = None + + if signature is None and pin is None: + comfy.pinned_memory.pin_memory(s) + pin = comfy.pinned_memory.get_pin(s) + else: + pin = None + + if pin is not None: + comfy.model_management.cast_to_gathered(xfer_source, pin) + xfer_source = [ pin ] + #send it over + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + comfy.model_management.sync_stream(device, offload_stream) + + if cast_dest is not None: + for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest), + comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + if post_cast is not None: + post_cast.copy_(pre_cast) + xfer_dest = cast_dest + + params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + weight = params[0] + bias = params[1] + + def post_cast(s, param_key, x, dtype, resident, update_weight): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + fns = getattr(s, param_key + "_function", []) + + orig = x + + def to_dequant(tensor, dtype): + tensor = tensor.to(dtype=dtype) + if isinstance(tensor, QuantizedTensor): + tensor = tensor.dequantize() + return tensor + + if orig.dtype != dtype or len(fns) > 0: + x = to_dequant(x, dtype) + if not resident and lowvram_fn is not None: + x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) + #FIXME: this is not accurate, we need to be sensitive to the compute dtype + x = lowvram_fn(x) + if (isinstance(orig, QuantizedTensor) and + (orig.dtype == dtype and len(fns) == 0 or update_weight)): + seed = comfy.utils.string_to_seed(s.seed_key) + y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + if orig.dtype == dtype and len(fns) == 0: + #The layer actually wants our freshly saved QT + x = y + else: + y = x + if update_weight: + orig.copy_(y) + for f in fns: + x = f(x) + return x + + update_weight = signature is not None + + weight = post_cast(s, "weight", weight, dtype, resident, update_weight) + if s.bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) + s._v_signature=signature + + #FIXME: weird offload return protocol + return weight, bias, (offload_stream, device if signature is not None else None, None) + + +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None): # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This # will add async-offload support to your cast and improve performance. @@ -87,22 +201,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + non_blocking = comfy.model_management.device_supports_non_blocking(device) + + if hasattr(s, "_v"): + return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): offload_stream = comfy.model_management.get_offload_stream(device) else: offload_stream = None - non_blocking = comfy.model_management.device_supports_non_blocking(device) + bias = None + weight = None + + if offload_stream is not None and not args.cuda_malloc: + cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ]) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + #The streams can be uneven in buffer capability and reject us. Retry to get the other stream + if cast_buffer is None: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer) + weight = params[0] + bias = params[1] weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 - weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight) - bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias) comfy.model_management.sync_stream(device, offload_stream) @@ -110,6 +240,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight_a = weight if s.bias is not None: + bias = bias.to(dtype=bias_dtype) for f in s.bias_function: bias = f(bias) @@ -131,14 +262,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return os, weight_a, bias_a = offload_stream + device=None + #FIXME: This is not good RTTI + if not isinstance(weight_a, torch.Tensor): + comfy_aimdo.model_vbar.vbar_unpin(s._v) + device = weight_a if os is None: return - if weight_a is not None: - device = weight_a.device - else: - if bias_a is None: - return - device = bias_a.device + if device is None: + if weight_a is not None: + device = weight_a.device + else: + if bias_a is None: + return + device = bias_a.device os.wait_stream(comfy.model_management.current_stream(device)) @@ -149,6 +286,57 @@ class CastWeightBiasOp: class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): + + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + super().__init__(in_features, out_features, bias, device, dtype) + return + + # Issue is with `torch.empty` still reserving the full memory for the layer. + # Windows doesn't over-commit memory so without this, We are momentarily commit + # charged for the weight even though we might zero-copy it when we load the + # state dict. If the commit charge exceeds the ceiling we can destabilize the + # system. + torch.nn.Module.__init__(self) + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.bias = None + self.comfy_need_lazy_init_bias=bias + self.weight_comfy_model_dtype = dtype + self.bias_comfy_model_dtype = dtype + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + prefix_len = len(prefix) + for k,v in state_dict.items(): + if k[prefix_len:] == "weight": + if not assign_to_params_buffers: + v = v.clone() + self.weight = torch.nn.Parameter(v, requires_grad=False) + elif k[prefix_len:] == "bias" and v is not None: + if not assign_to_params_buffers: + v = v.clone() + self.bias = torch.nn.Parameter(v, requires_grad=False) + else: + unexpected_keys.append(k) + + #Reconcile default construction of the weight if its missing. + if self.weight is None: + v = torch.zeros(self.in_features, self.out_features) + self.weight = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"weight") + if self.bias is None and self.comfy_need_lazy_init_bias: + v = torch.zeros(self.out_features,) + self.bias = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"bias") + + def reset_parameters(self): return None @@ -655,8 +843,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + def forward_comfy_cast_weights(self, input, compute_dtype=None): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -666,6 +854,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec input_shape = input.shape reshaped_3d = False + #If cast needs to apply lora, it should be done in the compute dtype + compute_dtype = input.dtype if (getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and @@ -684,7 +874,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec scale = comfy.model_management.cast_to_device(scale, input.device, None) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) - output = self.forward_comfy_cast_weights(input) + + output = self.forward_comfy_cast_weights(input, compute_dtype) # Reshape output back to 3D if input was 3D if reshaped_3d: diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py new file mode 100644 index 000000000..0650e4d1a --- /dev/null +++ b/comfy/pinned_memory.py @@ -0,0 +1,30 @@ +import torch +import comfy.model_management +import comfy.memory_management + +from comfy.cli_args import args + +def get_pin(module): + return getattr(module, "_pin", None) + +def pin_memory(module): + if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + return + #FIXME: This is a RAM cache trigger event + params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ]) + size = comfy.memory_management.vram_aligned_size(params) + pin = torch.empty((size,), dtype=torch.uint8) + if comfy.model_management.pin_memory(pin): + module._pin = pin + else: + module.pin_failed = True + return False + return True + +def unpin_memory(module): + if get_pin(module) is None: + return 0 + size = module._pin.numel() * module._pin.element_size() + comfy.model_management.unpin_memory(module._pin) + del module._pin + return size diff --git a/comfy/samplers.py b/comfy/samplers.py index 1989ef107..8b9782956 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: import torch from functools import partial import collections -from comfy import model_management import math import logging import comfy.sampler_helpers @@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens to_batch_temp.reverse() to_batch = to_batch_temp[:1] - free_memory = model_management.get_free_memory(x_in.device) + free_memory = model.current_patcher.get_free_memory(x_in.device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] diff --git a/comfy/sd.py b/comfy/sd.py index f627f7d55..fd0ac85e7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -228,8 +228,10 @@ class CLIP: self.cond_stage_model.to(offload_device) logging.warning("Had to shift TE back.") + model_management.archive_model_dtypes(self.cond_stage_model) + self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -389,8 +391,18 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: - return self.cond_stage_model.load_state_dict(sd, strict=False) + return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) else: + can_assign = self.patcher.is_dynamic() + self.cond_stage_model.can_assign_sd = can_assign + + # The CLIP models are a pretty complex web of wrappers and its + # a bit of an API change to plumb this all the way through. + # So spray paint the model with this flag that the loading + # nn.Module can then inspect for itself. + for m in self.cond_stage_model.modules(): + m.can_assign_sd = can_assign + return self.cond_stage_model.load_sd(sd) def get_sd(self): @@ -765,12 +777,7 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = self.first_stage_model.load_state_dict(sd, strict=False) - if len(m) > 0: - logging.warning("Missing VAE keys {}".format(m)) - - if len(u) > 0: - logging.debug("Leftover VAE keys {}".format(u)) + model_management.archive_model_dtypes(self.first_stage_model) if device is None: device = model_management.vae_device() @@ -782,7 +789,18 @@ class VAE: self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device() - self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + mp = comfy.model_patcher.CoreModelPatcher + if self.disable_offload: + mp = comfy.model_patcher.ModelPatcher + self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device) + + m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) + if len(m) > 0: + logging.warning("Missing VAE keys {}".format(m)) + + if len(u) > 0: + logging.debug("Leftover VAE keys {}".format(u)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) self.model_size() @@ -897,7 +915,7 @@ class VAE: try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -971,7 +989,7 @@ class VAE: try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) samples = None @@ -1432,7 +1450,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) + return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def model_detection_error_hint(path, state_dict): filename = os.path.basename(path) @@ -1520,7 +1538,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model.load_model_weights(sd, diffusion_model_prefix) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) @@ -1563,7 +1582,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded diffusion model directly to GPU") model_management.load_models_gpu([model_patcher], force_full_load=True) @@ -1655,13 +1673,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model = model.to(offload_device) - model.load_model_weights(new_sd, "") + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) + if not model_management.is_device_cpu(offload_device): + model.to(offload_device) + model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) left_over = sd.keys() if len(left_over) > 0: logging.info("left over keys in diffusion model: {}".format(left_over)) - return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) - + return model_patcher def load_diffusion_model(unet_path, model_options={}): sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) @@ -1692,9 +1711,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if metadata is None: metadata = {} - model_management.load_models_gpu(load_models, force_patch_weights=True) + model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) + sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) for k in extra_keys: sd[k] = extra_keys[k] diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index d4f22120b..9ecfc9c55 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): - return self.transformer.load_state_dict(sd, strict=False) + return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) def parse_parentheses(string): result = [] diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e49161964..26573fb12 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -125,7 +125,7 @@ class LTXAVTEModel(torch.nn.Module): for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} if component_sd: - missing, unexpected = component.load_state_dict(component_sd, strict=False) + missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) missing_all.extend([f"{prefix}{k}" for k in missing]) unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) diff --git a/comfy/utils.py b/comfy/utils.py index d97d753e6..9e98eb176 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,9 +28,10 @@ import logging import itertools from torch.nn.functional import interpolate from einops import rearrange -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram import json import time +import mmap MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -56,21 +57,67 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in else: logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") +# Current as of safetensors 0.7.0 +_TYPES = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + "C64": torch.complex64, + + "U64": torch.uint64, + "U32": torch.uint32, + "U16": torch.uint16, +} + +def load_safetensors(ckpt): + f = open(ckpt, "rb") + mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + header_size = struct.unpack(" 0: message = e.args[0] @@ -1308,3 +1355,16 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8) return state_dict, metadata + +def string_to_seed(data): + crc = 0xFFFFFFFF + for byte in data: + if isinstance(byte, str): + byte = ord(byte) + crc ^= byte + for _ in range(8): + if crc & 1: + crc = (crc >> 1) ^ 0xEDB88320 + else: + crc >>= 1 + return crc ^ 0xFFFFFFFF diff --git a/comfy/windows.py b/comfy/windows.py new file mode 100644 index 000000000..213dc481d --- /dev/null +++ b/comfy/windows.py @@ -0,0 +1,52 @@ +import ctypes +import logging +import psutil +from ctypes import wintypes + +import comfy_aimdo.control + +psapi = ctypes.WinDLL("psapi") +kernel32 = ctypes.WinDLL("kernel32") + +class PERFORMANCE_INFORMATION(ctypes.Structure): + _fields_ = [ + ("cb", wintypes.DWORD), + ("CommitTotal", ctypes.c_size_t), + ("CommitLimit", ctypes.c_size_t), + ("CommitPeak", ctypes.c_size_t), + ("PhysicalTotal", ctypes.c_size_t), + ("PhysicalAvailable", ctypes.c_size_t), + ("SystemCache", ctypes.c_size_t), + ("KernelTotal", ctypes.c_size_t), + ("KernelPaged", ctypes.c_size_t), + ("KernelNonpaged", ctypes.c_size_t), + ("PageSize", ctypes.c_size_t), + ("HandleCount", wintypes.DWORD), + ("ProcessCount", wintypes.DWORD), + ("ThreadCount", wintypes.DWORD), + ] + +def get_free_ram(): + #Windows is way too conservative and chalks recently used uncommitted model RAM + #as "in-use". So, calculate free RAM for the sake of general use as the greater of: + # + #1: What psutil says + #2: Total Memory - (Committed Memory - VRAM in use) + # + #We have to subtract VRAM in use from the comitted memory as WDDM creates a naked + #commit charge for all VRAM used just incase it wants to page it all out. This just + #isn't realistic so "overcommit" on our calculations by just subtracting it off. + + pi = PERFORMANCE_INFORMATION() + pi.cb = ctypes.sizeof(pi) + + if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb): + logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal") + return psutil.virtual_memory().available + + committed = pi.CommitTotal * pi.PageSize + total = pi.PhysicalTotal * pi.PageSize + + return max(psutil.virtual_memory().available, + total - (committed - comfy_aimdo.control.get_total_vram_usage())) + diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 82c4754a3..176e6bc2f 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -267,9 +267,9 @@ class ModelPatchLoader: device=comfy.model_management.unet_offload_device(), operations=comfy.ops.manual_cast) - model.load_state_dict(sd) - model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) - return (model,) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + model.load_state_dict(sd, assign=model_patcher.is_dynamic()) + return (model_patcher,) class DiffSynthCnetPatch: diff --git a/cuda_malloc.py b/cuda_malloc.py index ee2bc4b69..b2182df37 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,8 +1,10 @@ import os import importlib.util -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import subprocess +import comfy_aimdo.control + #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. def get_gpu_names(): if os.name == 'nt': @@ -85,8 +87,14 @@ if not args.cuda_malloc: except: pass +if enables_dynamic_vram() and comfy_aimdo.control.init(): + args.cuda_malloc = False + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "" -if args.cuda_malloc and not args.disable_cuda_malloc: +if args.disable_cuda_malloc: + args.cuda_malloc = False + +if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" diff --git a/execution.py b/execution.py index 4b4f63c80..93fafc4a2 100644 --- a/execution.py +++ b/execution.py @@ -9,9 +9,11 @@ import traceback from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union import asyncio +from contextlib import nullcontext import torch +import comfy.memory_management import comfy.model_management from latent_preview import set_preview_method import nodes @@ -515,7 +517,19 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + + #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows + #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc + #that we just want to cull out each model run. + allocator = comfy.memory_management.aimdo_allocator + with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): + try: + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + finally: + if allocator is not None: + comfy.model_management.reset_cast_buffers() + torch.cuda.synchronize() + if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) diff --git a/main.py b/main.py index 37b06c1fa..b8c951375 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ import os import importlib.util import folder_paths import time -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger from app.assets.scanner import seed_assets import itertools @@ -173,6 +173,7 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") + import comfy.utils import execution @@ -184,6 +185,33 @@ import comfyui_version import app.logger import hook_breaker_ac10a0 +import comfy.memory_management +import comfy.model_patcher + +import comfy_aimdo.control +import comfy_aimdo.torch + +if enables_dynamic_vram(): + if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + if args.verbose == 'DEBUG': + comfy_aimdo.control.set_log_debug() + elif args.verbose == 'CRITICAL': + comfy_aimdo.control.set_log_critical() + elif args.verbose == 'ERROR': + comfy_aimdo.control.set_log_error() + elif args.verbose == 'WARNING': + comfy_aimdo.control.set_log_warning() + else: #INFO + comfy_aimdo.control.set_log_info() + + comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic + comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() + logging.info("DynamicVRAM support detected and enabled") + else: + logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + comfy.memory_management.aimdo_allocator = None + + def cuda_malloc_warning(): device = comfy.model_management.get_torch_device() device_name = comfy.model_management.get_torch_device_name(device) diff --git a/requirements.txt b/requirements.txt index 4ac94cb16..be0bc537e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,6 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 +comfy-aimdo>=0.1.6 requests #non essential dependencies: From 32621c6a11f6ce1c0dd46aeb08e35dd64ec804f0 Mon Sep 17 00:00:00 2001 From: Christian Byrne Date: Sat, 31 Jan 2026 22:13:48 -0800 Subject: [PATCH 4/6] fix: improve error message when node type is missing (#12194) - Change error type from 'invalid_prompt' to 'missing_node_type' for frontend detection - Add extra_info with node_id, class_type, and node_title (from _meta.title) - Improve user-facing message: 'Node X not found. The custom node may not be installed.' --- execution.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/execution.py b/execution.py index 93fafc4a2..3dbab82e6 100644 --- a/execution.py +++ b/execution.py @@ -1014,22 +1014,34 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ outputs = set() for x in prompt: if 'class_type' not in prompt[x]: + node_data = prompt[x] + node_title = node_data.get('_meta', {}).get('title') error = { - "type": "invalid_prompt", - "message": "Cannot execute because a node is missing the class_type property.", + "type": "missing_node_type", + "message": f"Node '{node_title or f'ID #{x}'}' has no class_type. The workflow may be corrupted or a custom node is missing.", "details": f"Node ID '#{x}'", - "extra_info": {} + "extra_info": { + "node_id": x, + "class_type": None, + "node_title": node_title + } } return (False, error, [], {}) class_type = prompt[x]['class_type'] class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) if class_ is None: + node_data = prompt[x] + node_title = node_data.get('_meta', {}).get('title', class_type) error = { - "type": "invalid_prompt", - "message": f"Cannot execute because node {class_type} does not exist.", + "type": "missing_node_type", + "message": f"Node '{node_title}' not found. The custom node may not be installed.", "details": f"Node ID '#{x}'", - "extra_info": {} + "extra_info": { + "node_id": x, + "class_type": class_type, + "node_title": node_title + } } return (False, error, [], {}) From 667a1b8878187f7a5c8955b272e33c16e4c46bb7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 31 Jan 2026 22:55:18 -0800 Subject: [PATCH 5/6] Fix some custom nodes breaking. (#12203) --- comfy/model_patcher.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 57b53d8c5..c8e6f088f 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -111,6 +111,10 @@ def move_weight_functions(m, device): memory += f.move_to(device=device) return memory +def string_to_seed(data): + logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils") + return comfy.utils.string_to_seed(data) + class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key From 361b9a82a3e2445fc22e352df187138b0c8f67fb Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Sun, 1 Feb 2026 08:42:32 -0800 Subject: [PATCH 6/6] fix pinning with model defined dtype (#12208) pinned memory was converted back to pinning the CPU side weight without any changes. Fix the pinner to use the CPU weight and not the model defined geometry. This will either save RAM or stop buffer overruns when the types mismatch. Fix the model defined weight caster to use the [ s.weight, s.bias ] interpretation, as xfer_dest might be the flattened pin now. Fix the detection of needing to cast to not be conditional on !pin. --- comfy/ops.py | 22 +++++++++++----------- comfy/pinned_memory.py | 3 +-- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/comfy/ops.py b/comfy/ops.py index c3a1825ce..53c5e4dc3 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -96,16 +96,16 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu pin = comfy.pinned_memory.get_pin(s) if pin is not None: xfer_source = [ pin ] - else: - for data, geometry in zip([ s.weight, s.bias ], cast_geometry): - if data is None: - continue - if data.dtype != geometry.dtype: - cast_dest = xfer_dest - if cast_dest is None: - cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) - xfer_dest = None - break + + for data, geometry in zip([ s.weight, s.bias ], cast_geometry): + if data is None: + continue + if data.dtype != geometry.dtype: + cast_dest = xfer_dest + if cast_dest is None: + cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) + xfer_dest = None + break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) offload_stream = comfy.model_management.get_offload_stream(device) @@ -132,7 +132,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu comfy.model_management.sync_stream(device, offload_stream) if cast_dest is not None: - for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like(xfer_source, xfer_dest), + for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): if post_cast is not None: post_cast.copy_(pre_cast) diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 0650e4d1a..8acc327a7 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -11,8 +11,7 @@ def pin_memory(module): if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: return #FIXME: This is a RAM cache trigger event - params = comfy.memory_management.tensors_to_geometries([ module.weight, module.bias ]) - size = comfy.memory_management.vram_aligned_size(params) + size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) pin = torch.empty((size,), dtype=torch.uint8) if comfy.model_management.pin_memory(pin): module._pin = pin