diff --git a/comfy/model_management.py b/comfy/model_management.py index 758e718e8..cd035f017 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -19,7 +19,8 @@ import psutil import logging from enum import Enum -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram +import threading import torch import sys import platform @@ -650,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ soft_empty_cache() return unloaded_models -def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): +def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): cleanup_models_gc() global vram_state @@ -746,8 +747,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu current_loaded_models.insert(0, loaded_model) return -def load_model_gpu(model): - return load_models_gpu([model]) +def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load): + with torch.inference_mode(): + load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load) + soft_empty_cache() + +def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): + #Deliberately load models outside of the Aimdo mempool so they can be retained accross + #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are + #thread local. So exploit that to escape context + if enables_dynamic_vram(): + t = threading.Thread( + target=load_models_gpu_thread, + args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load) + ) + t.start() + t.join() + else: + load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights, + minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) def loaded_models(only_currently_used=False): output = [] @@ -1112,11 +1130,11 @@ def get_cast_buffer(offload_stream, device, size, ref): return None if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now - torch.cuda.synchronize() + synchronize() del STREAM_CAST_BUFFERS[offload_stream] del cast_buffer #FIXME: This doesn't work in Aimdo because mempool cant clear cache - torch.cuda.empty_cache() + soft_empty_cache() with wf_context: cast_buffer = torch.empty((size), dtype=torch.int8, device=device) STREAM_CAST_BUFFERS[offload_stream] = cast_buffer @@ -1132,9 +1150,7 @@ def reset_cast_buffers(): for offload_stream in STREAM_CAST_BUFFERS: offload_stream.synchronize() STREAM_CAST_BUFFERS.clear() - if comfy.memory_management.aimdo_allocator is None: - #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist - torch.cuda.empty_cache() + soft_empty_cache() def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) @@ -1202,27 +1218,36 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str assert r is None assert stream is None - r = torch.empty_like(weight, dtype=weight._model_dtype, device=device) + cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ]) + + if dtype is None: + dtype = weight._model_dtype + + r = torch.empty_like(weight, dtype=dtype, device=device) signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) if signature is not None: raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0] - - if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] + if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + weight._v_signature = signature + #Send it over + v_tensor.copy_(weight, non_blocking=non_blocking) #always take a deep copy even if _v is good, as we have no reasonable point to unpin #a non comfy weight r.copy_(v_tensor) comfy_aimdo.model_vbar.vbar_unpin(weight._v) return r + if weight.dtype != r.dtype and weight.dtype != weight._model_dtype: + #Offloaded casting could skip this, however it would make the quantizations + #inconsistent between loaded and offloaded weights. So force the double casting + #that would happen in regular flow to make offload deterministic. + cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device) + cast_buffer.copy_(weight, non_blocking=non_blocking) + weight = cast_buffer r.copy_(weight, non_blocking=non_blocking) - if signature is not None: - weight._v_signature = signature - v_tensor.copy_(r) - comfy_aimdo.model_vbar.vbar_unpin(weight._v) - return r if device is None or weight.device == device: @@ -1275,7 +1300,7 @@ def discard_cuda_async_error(): a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) _ = a + b - torch.cuda.synchronize() + synchronize() except torch.AcceleratorError: #Dump it! We already know about it from the synchronous return pass @@ -1679,6 +1704,12 @@ def lora_compute_dtype(device): LORA_COMPUTE_DTYPES[device] = dtype return dtype +def synchronize(): + if is_intel_xpu(): + torch.xpu.synchronize() + elif torch.cuda.is_available(): + torch.cuda.synchronize() + def soft_empty_cache(force=False): global cpu_state if cpu_state == CPUState.MPS: @@ -1704,9 +1735,6 @@ def debug_memory_summary(): return torch.cuda.memory.memory_summary() return "" -#TODO: might be cleaner to put this somewhere else -import threading - class InterruptProcessingException(Exception): pass diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c8e6f088f..cdf289395 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -161,6 +161,11 @@ def get_key_weight(model, key): return weight, set_func, convert_func +def key_param_name_to_key(key, param): + if len(key) == 0: + return param + return "{}.{}".format(key, param) + class AutoPatcherEjector: def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False): self.model = model @@ -795,7 +800,7 @@ class ModelPatcher: continue for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) if comfy.model_management.is_device_cuda(device_to): @@ -811,7 +816,7 @@ class ModelPatcher: n = x[1] params = x[3] for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + self.pin_weight_to_device(key_param_name_to_key(n, param)) usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else "" if lowvram_counter > 0: @@ -917,7 +922,7 @@ class ModelPatcher: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: move_weight = True for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) bk = self.backup.get(key, None) if bk is not None: if not lowvram_possible: @@ -968,7 +973,7 @@ class ModelPatcher: logging.debug("freed {}".format(n)) for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + self.pin_weight_to_device(key_param_name_to_key(n, param)) self.model.model_lowvram = True @@ -1501,7 +1506,7 @@ class ModelPatcherDynamic(ModelPatcher): def setup_param(self, m, n, param_key): nonlocal num_patches - key = "{}.{}".format(n, param_key) + key = key_param_name_to_key(n, param_key) weight_function = [] @@ -1540,7 +1545,7 @@ class ModelPatcherDynamic(ModelPatcher): else: for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) weight, _, _ = get_key_weight(self.model, key) weight.seed_key = key set_dirty(weight, dirty) @@ -1592,7 +1597,7 @@ class ModelPatcherDynamic(ModelPatcher): if unpatch_weights: self.partially_unload_ram(1e32) - self.partially_unload(None) + self.partially_unload(None, 1e32) def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py index 41f95bcb6..b6f58cb25 100644 --- a/comfy/text_encoders/anima.py +++ b/comfy/text_encoders/anima.py @@ -8,7 +8,7 @@ import torch class Qwen3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index f67a5f805..1ae398789 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -118,7 +118,7 @@ class MistralTokenizerClass: class Mistral3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): self.tekken_data = tokenizer_data.get("tekken_model", None) - super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) def state_dict(self): return {"tekken_model": self.tekken_data} @@ -176,12 +176,12 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): class Qwen3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) class Qwen3Tokenizer8B(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) class KleinTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"): diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index ad41bfb1e..33b7cf594 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -6,7 +6,7 @@ import os class Qwen3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) class ZImageTokenizer(sd1_clip.SD1Tokenizer): diff --git a/comfy/utils.py b/comfy/utils.py index 9e98eb176..c1b536833 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -32,6 +32,7 @@ from comfy.cli_args import args, enables_dynamic_vram import json import time import mmap +import warnings MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -85,7 +86,10 @@ def load_safetensors(ckpt): header_size = struct.unpack("=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.1.6 +comfy-aimdo>=0.1.7 requests #non essential dependencies: