From 5aa5ccc9e02aec94cf43e0f71d4b2f62b204b5b6 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 21 May 2026 10:03:58 +1000 Subject: [PATCH 1/3] Multi-threaded load of models from disk (big load time speedups & Offload to disk) (CORE-43,CORE-152,CORE-164,CORE-165,CORE-117) (#13802) * model_management: disable non-dynamic smart memory Disable smart memory outright for non dynamic models. This is a minor step towards deprecation of --disable-dynamic-vram and the legacy ModelPatcher. This is needed for estimate-free model development, where new models can opt-out of supplying a memory estimate and not have to worry about hard VRAM allocations due to legacy non-dynamic model patchers This is also a general stability increase for a lot of stray use cases where estimates may still be off and going forward we are not going to accurately maintain such estimates. * pinned_memory: implement with aimdo growable buffer Use a single growable buffer so we can do threaded pre-warming on pinned memory. * mm: use aimdo to do transfer from disk to pin Aimdo implements a faster threaded loader. * Add stream host pin buffer for AIMDO casts Introduce per-offload-stream HostBuffer reuse for pinned staging, include it in cast buffer reset synchronization. Defer actual casts that go via this pin path to a separate pass such that the buffer can be allocated monolithically (to avoid cudaHostRegister thrash). * remove old pin path * Implement JIT pinned memory pressure Replace the predictive pin pressure mechanism with JIT PIN memory pressure. * LowVRAMPatch: change to two-phase visit * lora: re-implement as inplace swiss-army-knife operation * prepare for multiple pin sets * implement pinned loras * requirements: comfy-aimdo 0.4.0 * ops: remove unused arg This was defeatured in aimdo iteration * ops: sync the CPU with only the offload stream activity This was syncing with the offload stream which itself is synced with the compute stream, so this was syncing CPU with compute transitively. Define the event to sync it more gently. * pins: implement freeing intermediate for pinned memory Pinning is more important than inactive intermediates and the stream pin buffer is more important than even active intermediates. * execution: implement pin eviction on RAM presure Add back proper pin freeing on RAM pressure * implement pin registration swaps Uncap the windows pins from 50% by extending the pool and have a pressure mechanism to move the pin reservations om demand. This unfortunately implies a GPU sync to do the freeing so significant hysterisis needs to be added to consolidate these pressure events. * cli_args/execution: Implement lower background cache-ram threshold Limit the amount of RAM background intermediates can use, so that switching workflows doesn't degrade performance too much. * make default * bump aimdo * model-patcher: force-cast tiny weights Flux 2 gets crazy stalls due to a mix of tiny and giant weights creating lopsided steam buffer rotations which creates stalls. * ops: refactor in prep for chunking * mm: delegate pin-on-the-way to aimdo Aimdo is able to chunk and slice this on the way for better CPU->GPU overlap. The main advantage is the ability to shorten the bus contention window between previous weight transfer and the next weights vbar fault. * bump aimdo * pinning updates * specify hostbuf max allocation size There a signs of virtual memory exhaustion on some linux systems when throwing 128GB for every little piece. Pass the actual to save aimdo from over-estimates * tests: update execution tests for caching The default caching changed to ram-cache so update these tests accordingly. Remove the LRU 0 test as this also falls through to RAM cache. --- comfy/cli_args.py | 7 +- comfy/lora.py | 19 ++- comfy/memory_management.py | 24 +++- comfy/model_management.py | 189 +++++++++++++++++----------- comfy/model_patcher.py | 138 +++++++++++++++----- comfy/ops.py | 88 +++++++++++-- comfy/pinned_memory.py | 68 ++++++---- comfy/utils.py | 2 - comfy/windows.py | 52 -------- execution.py | 12 +- main.py | 20 +-- requirements.txt | 2 +- tests/execution/test_async_nodes.py | 3 +- tests/execution/test_execution.py | 3 +- 14 files changed, 408 insertions(+), 219 deletions(-) delete mode 100644 comfy/windows.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 76faed3ad..9d88c8517 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -110,13 +110,11 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") -CACHE_RAM_AUTO_GB = -1.0 - cache_group = parser.add_mutually_exclusive_group() +cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 25%% of system RAM (min 4GB, max 32GB), inactive 75%% of system RAM (min 12GB, max 96GB).") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") -cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") @@ -245,6 +243,9 @@ if comfy.options.args_parsing: else: args = parser.parse_args([]) +if args.cache_ram is not None and len(args.cache_ram) > 2: + parser.error("--cache-ram accepts at most two values: active GB and inactive GB") + if args.windows_standalone_build: args.auto_launch = True diff --git a/comfy/lora.py b/comfy/lora.py index f11e26ec9..c0e8b865c 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -484,16 +484,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori return weight -def prefetch_prepared_value(value, allocate_buffer, stream): +def prefetch_prepared_value(value, counter, destination, stream, copy): if isinstance(value, torch.Tensor): - dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) - comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + size = comfy.memory_management.vram_aligned_size(value) + offset = counter[0] + counter[0] += size + if destination is None: + return value + + dest = destination[offset:offset + size] + if copy: + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) return comfy.memory_management.interpret_gathered_like([value], dest)[0] elif isinstance(value, weight_adapter.WeightAdapterBase): - return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy)) elif isinstance(value, tuple): - return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value) elif isinstance(value, list): - return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value] return value diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 48e3c11da..c43f0c4a2 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -15,7 +15,7 @@ class TensorFileSlice(NamedTuple): size: int -def read_tensor_file_slice_into(tensor, destination): +def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None): if isinstance(tensor, QuantizedTensor): if not isinstance(destination, QuantizedTensor): @@ -23,12 +23,17 @@ def read_tensor_file_slice_into(tensor, destination): if tensor._layout_cls != destination._layout_cls: return False - if not read_tensor_file_slice_into(tensor._qdata, destination._qdata): + if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream, + destination2=(destination2._qdata if destination2 is not None else None)): return False dst_orig_dtype = destination._params.orig_dtype destination._params.copy_from(tensor._params, non_blocking=False) destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) + if destination2 is not None: + dst_orig_dtype = destination2._params.orig_dtype + destination2._params.copy_from(destination._params, non_blocking=True) + destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype) return True info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None) @@ -48,6 +53,17 @@ def read_tensor_file_slice_into(tensor, destination): if info.size == 0: return True + hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None) + if hostbuf is not None: + stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 + device_ptr = destination2.data_ptr() if destination2 is not None else 0 + hostbuf.read_file_slice(file_obj, info.offset, info.size, + offset=destination.data_ptr() - hostbuf.get_raw_address(), + stream=stream_ptr, + device_ptr=device_ptr, + device=None if destination2 is None else destination2.device.index) + return True + buf_type = ctypes.c_ubyte * info.size view = memoryview(buf_type.from_address(destination.data_ptr())) @@ -151,7 +167,7 @@ def set_ram_cache_release_state(callback, headroom): extra_ram_release_callback = callback RAM_CACHE_HEADROOM = max(0, int(headroom)) -def extra_ram_release(target): +def extra_ram_release(target, free_active=False): if extra_ram_release_callback is None: return 0 - return extra_ram_release_callback(target) + return extra_ram_release_callback(target, free_active=free_active) diff --git a/comfy/model_management.py b/comfy/model_management.py index 21738a4c7..3894dfa9c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.host_buffer import comfy_aimdo.vram_buffer class VRAMState(Enum): @@ -495,6 +496,14 @@ except: current_loaded_models = [] +DIRTY_MMAPS = set() + +PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024 + +#Freeing registerables on pressure does imply a GPU sync, so go big on +#the hysteresis so each expensive sync gives us back a good chunk. +REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024 + def module_size(module): module_mem = 0 sd = module.state_dict() @@ -503,27 +512,46 @@ def module_size(module): module_mem += t.nbytes return module_mem -def module_mmap_residency(module, free=False): - mmap_touched_mem = 0 - module_mem = 0 - bounced_mmaps = set() - sd = module.state_dict() - for k in sd: - t = sd[k] - module_mem += t.nbytes - storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage() - if not getattr(storage, "_comfy_tensor_mmap_touched", False): - continue - mmap_touched_mem += t.nbytes - if not free: - continue - storage._comfy_tensor_mmap_touched = False - mmap_obj = storage._comfy_tensor_mmap_refs[0] - if mmap_obj in bounced_mmaps: - continue - mmap_obj.bounce() - bounced_mmaps.add(mmap_obj) - return mmap_touched_mem, module_mem +def mark_mmap_dirty(storage): + mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None) + if mmap_refs is not None: + DIRTY_MMAPS.add(mmap_refs[0]) + +def free_pins(size, evict_active=False): + freed_total = 0 + for loaded_model in reversed(current_loaded_models): + if size <= 0: + return freed_total + model = loaded_model.model + if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): + freed = model.partially_unload_ram(size) + freed_total += freed + size -= freed + return freed_total + +def ensure_pin_budget(size, evict_active=False): + shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available + if shortfall <= 0: + return True + + to_free = shortfall + PIN_PRESSURE_HYSTERESIS + return free_pins(to_free, evict_active=evict_active) >= shortfall + +def ensure_pin_registerable(size, evict_active=False): + shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: + return False + if shortfall <= 0: + return True + + shortfall += REGISTERABLE_PIN_HYSTERESIS + for loaded_model in reversed(current_loaded_models): + model = loaded_model.model + if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): + shortfall -= model.unregister_inactive_pins(shortfall) + if shortfall <= 0: + return True + return shortfall <= REGISTERABLE_PIN_HYSTERESIS class LoadedModel: def __init__(self, model): @@ -553,9 +581,6 @@ class LoadedModel: def model_memory(self): return self.model.model_size() - def model_mmap_residency(self, free=False): - return self.model.model_mmap_residency(free=free) - def model_loaded_memory(self): return self.model.loaded_size() @@ -635,15 +660,9 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: - import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 - def get_free_ram(): - return comfy.windows.get_free_ram() -else: - def get_free_ram(): - return psutil.virtual_memory().available if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -657,7 +676,6 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): cleanup_models_gc() - comfy.memory_management.extra_ram_release(max(pins_required, ram_required)) unloaded_model = [] can_unload = [] unloaded_models = [] @@ -673,11 +691,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins for x in can_unload_sorted: i = x[-1] memory_to_free = 1e32 - pins_to_free = 1e32 - if not DISABLE_SMART_MEMORY or device is None: + if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None): memory_to_free = 0 if device is None else memory_required - get_free_memory(device) - pins_to_free = pins_required - get_free_ram() - if current_loaded_models[i].model.is_dynamic() and for_dynamic: + if for_dynamic: #don't actually unload dynamic models for the sake of other dynamic models #as that works on-demand. memory_required -= current_loaded_models[i].model.loaded_size() @@ -685,18 +701,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) - if pins_to_free > 0: - logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}") - current_loaded_models[i].model.partially_unload_ram(pins_to_free) - - for x in can_unload_sorted: - i = x[-1] - ram_to_free = ram_required - psutil.virtual_memory().available - if ram_to_free <= 0 and i not in unloaded_model: - continue - resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True) - if resident_memory > 0: - logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -762,29 +766,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() - total_memory_required = {} - total_pins_required = {} - total_ram_required = {} for loaded_model in models_to_load: device = loaded_model.device total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) - resident_memory, model_memory = loaded_model.model_mmap_residency() - pinned_memory = loaded_model.model.pinned_memory_size() - #FIXME: This can over-free the pins as it budgets to pin the entire model. We should - #make this JIT to keep as much pinned as possible. - pins_required = model_memory - pinned_memory - ram_required = model_memory - resident_memory - total_pins_required[device] = total_pins_required.get(device, 0) + pins_required - total_ram_required[device] = total_ram_required.get(device, 0) + ram_required for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.1 + extra_mem, device, - for_dynamic=free_for_dynamic, - pins_required=total_pins_required[device], - ram_required=total_ram_required[device]) + for_dynamic=free_for_dynamic) for device in total_memory_required: if device != torch.device("cpu"): @@ -1180,6 +1171,7 @@ STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) STREAM_AIMDO_CAST_BUFFERS = {} LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) +STREAM_PIN_BUFFERS = {} DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 @@ -1220,21 +1212,66 @@ def get_aimdo_cast_buffer(offload_stream, device): if cast_buffer is None: cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer - return cast_buffer + +def get_pin_buffer(offload_stream): + pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None) + if pin_buffer is None: + pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3)) + STREAM_PIN_BUFFERS[offload_stream] = pin_buffer + elif offload_stream is not None: + event = getattr(pin_buffer, "_comfy_event", None) + if event is not None: + event.synchronize() + delattr(pin_buffer, "_comfy_event") + return pin_buffer + +def resize_pin_buffer(pin_buffer, size): + global TOTAL_PINNED_MEMORY + old_size = pin_buffer.size + if size <= old_size: + return True + growth = size - old_size + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + ensure_pin_budget(growth, evict_active=True) + ensure_pin_registerable(growth, evict_active=True) + try: + pin_buffer.extend(size=size, reallocate=True) + except RuntimeError: + return False + TOTAL_PINNED_MEMORY += pin_buffer.size - old_size + return True + def reset_cast_buffers(): + global TOTAL_PINNED_MEMORY global LARGEST_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) - for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS): if offload_stream is not None: offload_stream.synchronize() synchronize() + for mmap_obj in DIRTY_MMAPS: + mmap_obj.bounce() + DIRTY_MMAPS.clear() + + for pin_buffer in STREAM_PIN_BUFFERS.values(): + TOTAL_PINNED_MEMORY -= pin_buffer.size + TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY) + + for loaded_model in current_loaded_models: + model = loaded_model.model + if model is not None and model.is_dynamic(): + model.model.dynamic_pins[model.load_device]["active"] = False + model.partially_unload_ram(1e30, subsets=[ "patches" ]) + model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0]) + STREAM_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear() + STREAM_PIN_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): @@ -1280,7 +1317,7 @@ def sync_stream(device, stream): current_stream(device).wait_stream(stream) -def cast_to_gathered(tensors, r, non_blocking=False, stream=None): +def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None): wf_context = nullcontext() if stream is not None: wf_context = stream @@ -1288,17 +1325,20 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): wf_context = wf_context.as_context(stream) dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None with wf_context: for tensor in tensors: dest_view = dest_views.pop(0) + dest2_view = dest2_views.pop(0) if dest2_views is not None else None if tensor is None: continue - if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view): + if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view): continue storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() - if hasattr(storage, "_comfy_tensor_mmap_touched"): - storage._comfy_tensor_mmap_touched = True + mark_mmap_dirty(storage) dest_view.copy_(tensor, non_blocking=non_blocking) + if dest2_view is not None: + dest2_view.copy_(dest_view, non_blocking=non_blocking) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): @@ -1339,14 +1379,18 @@ TOTAL_PINNED_MEMORY = 0 MAX_PINNED_MEMORY = -1 if not args.disable_pinned_memory: if is_nvidia() or is_amd(): + ram = get_total_memory(torch.device("cpu")) if WINDOWS: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50% + MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50% else: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90 + MAX_PINNED_MEMORY = ram * 0.90 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) +def pinned_hostbuf_size(size): + return max(0, int(min(size, MAX_PINNED_MEMORY) * 2)) + def discard_cuda_async_error(): try: a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) @@ -1378,8 +1422,8 @@ def pin_memory(tensor): return False size = tensor.nbytes - if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: - return False + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + ensure_pin_registerable(size) ptr = tensor.data_ptr() if ptr == 0: @@ -1416,7 +1460,8 @@ def unpin_memory(tensor): return False if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: - TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) + size = PINNED_MEMORY.pop(ptr) + TOTAL_PINNED_MEMORY -= size return True else: logging.warning("Unpin error.") diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4f9d8403e..c8ed02e70 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -35,6 +35,7 @@ import comfy.model_management import comfy.ops import comfy.patcher_extension import comfy.utils +import comfy_aimdo.host_buffer from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -117,6 +118,8 @@ def string_to_seed(data): return comfy.utils.string_to_seed(data) class LowVramPatch: + is_lowvram_patch = True + def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches @@ -124,11 +127,21 @@ class LowVramPatch: self.set_func = set_func self.prepared_patches = None - def prepare(self, allocate_buffer, stream): - self.prepared_patches = [ - (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + def memory_required(self): + counter = [0] + for patch in self.patches[self.key]: + comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False) + return counter[0] + + def prepare(self, destination, stream, copy=True, commit=True): + counter = [0] + prepared_patches = [ + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4]) for patch in self.patches[self.key] ] + if commit: + self.prepared_patches = prepared_patches + return prepared_patches def clear_prepared(self): self.prepared_patches = None @@ -341,9 +354,6 @@ class ModelPatcher: self.size = comfy.model_management.module_size(self.model) return self.size - def model_mmap_residency(self, free=False): - return comfy.model_management.module_mmap_residency(self.model, free=free) - def loaded_size(self): return self.model.model_loaded_weight_memory @@ -1118,8 +1128,12 @@ class ModelPatcher: # Pinned memory pressure tracking is only implemented for DynamicVram loading return 0 + def loaded_ram_size(self): + # Loaded RAM pressure tracking is only implemented for DynamicVram loading + return 0 + def partially_unload_ram(self, ram_to_unload): - pass + return 0 def detach(self, unpatch_all=True): self.eject_model() @@ -1550,6 +1564,16 @@ class ModelPatcherDynamic(ModelPatcher): super().__init__(model, load_device, offload_device, size, weight_inplace_update) if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} + if not hasattr(self.model, "dynamic_pins"): + self.model.dynamic_pins = {} + if self.load_device not in self.model.dynamic_pins: + self.model.dynamic_pins[self.load_device] = { + "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), + "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), + "hostbufs_initialized": False, + "failed": False, + "active": False, + } self.non_dynamic_delegate_model = None assert load_device is not None @@ -1611,6 +1635,14 @@ class ModelPatcherDynamic(ModelPatcher): self.unpatch_hooks() vbar = self._vbar_get(create=True) + pin_state = self.model.dynamic_pins[self.load_device] + if not pin_state["hostbufs_initialized"]: + hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size()) + pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0]) + pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0]) + pin_state["hostbufs_initialized"] = True + pin_state["failed"] = False + pin_state["active"] = True if vbar is not None: vbar.prioritize() @@ -1636,7 +1668,9 @@ class ModelPatcherDynamic(ModelPatcher): if key in self.patches: if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape: return (True, 0) - setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + lowvram_patch = LowVramPatch(key, self.patches) + lowvram_patch._pin_state = pin_state + setattr(m, param_key + "_lowvram_function", lowvram_patch) num_patches += 1 else: setattr(m, param_key + "_lowvram_function", None) @@ -1653,6 +1687,9 @@ class ModelPatcherDynamic(ModelPatcher): def force_load_param(self, param_key, device_to): key = key_param_name_to_key(n, param_key) + weight, _, _ = get_key_weight(self.model, key) + if weight is None: + return if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) self.patch_weight_to_device(key, device_to=device_to, force_cast=True) @@ -1662,17 +1699,23 @@ class ModelPatcherDynamic(ModelPatcher): if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True - m.pin_failed = False m.seed_key = n + m._pin_state = pin_state set_dirty(m, dirty) - force_load, v_weight_size = setup_param(self, m, n, "weight") - force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") - force_load = force_load or force_load_bias - v_weight_size += v_weight_bias + #Models that mix tiny and giant weights can causing lopsided stream buffer + #rotations and stall. force the tinys over. + if module_mem > 16 * 1024: + force_load, v_weight_size = setup_param(self, m, n, "weight") + force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") + force_load = force_load or force_load_bias + v_weight_size += v_weight_bias + if force_load: + logging.info(f"Module {n} has resizing Lora - force loading") + else: + force_load=True if force_load: - logging.info(f"Module {n} has resizing Lora - force loading") force_load_param(self, "weight", device_to) force_load_param(self, "bias", device_to) else: @@ -1740,23 +1783,58 @@ class ModelPatcherDynamic(ModelPatcher): return freed - def pinned_memory_size(self): - total = 0 - loading = self._load_list(for_dynamic=True) - for x in loading: - _, _, _, _, m, _ = x - pin = comfy.pinned_memory.get_pin(m) - if pin is not None: - total += pin.numel() * pin.element_size() - return total + def loaded_ram_size(self): + return (self.model.dynamic_pins[self.load_device]["weights"][0].size + + self.model.dynamic_pins[self.load_device]["patches"][0].size) - def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(for_dynamic=True, default_device=self.offload_device) - for x in loading: - *_, m, _ = x - ram_to_unload -= comfy.pinned_memory.unpin_memory(m) - if ram_to_unload <= 0: - return + def pinned_memory_size(self): + return (self.model.dynamic_pins[self.load_device]["weights"][3][0] + + self.model.dynamic_pins[self.load_device]["patches"][3][0]) + + def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]): + freed = 0 + pin_state = self.model.dynamic_pins[self.load_device] + for subset in subsets: + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + split = stack_split[0] + while split >= 0: + module, offset = stack[split] + split -= 1 + stack_split[0] = split + if not module._pin_registered: + continue + size = module._pin.numel() * module._pin.element_size() + if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0: + comfy.model_management.discard_cuda_async_error() + continue + module._pin_registered = False + comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size) + pinned_size[0] = max(0, pinned_size[0] - size) + freed += size + ram_to_unload -= size + if ram_to_unload <= 0: + return freed + return freed + + def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]): + freed = 0 + pin_state = self.model.dynamic_pins[self.load_device] + for subset in subsets: + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + while len(stack) > 0: + module, offset = stack.pop() + size = module._pin.numel() * module._pin.element_size() + del module._pin + hostbuf.truncate(offset, do_unregister=module._pin_registered) + stack_split[0] = min(stack_split[0], len(stack) - 1) + if module._pin_registered: + comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size) + pinned_size[0] = max(0, pinned_size[0] - size) + freed += size + ram_to_unload -= size + if ram_to_unload <= 0: + return freed + return freed def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): #This isn't used by the core at all and can only be to load a model out of diff --git a/comfy/ops.py b/comfy/ops.py index eae3bd873..9bcd6c900 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -75,6 +75,8 @@ except: cast_to = comfy.model_management.cast_to #TODO: remove once no more references +STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024 + def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin offload_stream = None cast_buffer = None cast_buffer_offset = 0 + stream_pin_hostbuf = None + stream_pin_offset = 0 + stream_pin_queue = [] def ensure_offload_stream(module, required_size, check_largest): nonlocal offload_stream @@ -124,6 +129,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin cast_buffer_offset += buffer_size return buffer + def get_stream_pin_buffer_offset(buffer_size): + nonlocal stream_pin_hostbuf + nonlocal stream_pin_offset + + if buffer_size == 0 or offload_stream is None: + return None + + if stream_pin_hostbuf is None: + stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream) + if stream_pin_hostbuf is None: + return None + + offset = stream_pin_offset + stream_pin_offset += buffer_size + return offset + for s in comfy_modules: signature = comfy_aimdo.model_vbar.vbar_fault(s._v) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) @@ -162,23 +183,47 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if xfer_dest is None: xfer_dest = get_cast_buffer(dest_size) - if signature is None and pin is None: - comfy.pinned_memory.pin_memory(s) - pin = comfy.pinned_memory.get_pin(s) - else: - pin = None + def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream): + if xfer_source is not None: + if getattr(xfer_source, "is_lowvram_patch", False): + xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) + else: + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream) - if pin is not None: - comfy.model_management.cast_to_gathered(xfer_source, pin) - xfer_source = [ pin ] - #send it over - comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + def handle_pin(m, pin, source, dest, subset="weights", size=None): + if pin is not None: + cast_maybe_lowvram_patch([pin], dest, offload_stream) + return + if signature is None: + comfy.pinned_memory.pin_memory(m, subset=subset, size=size) + pin = comfy.pinned_memory.get_pin(m, subset=subset) + if pin is not None: + if isinstance(source, list): + comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest) + else: + cast_maybe_lowvram_patch(source, pin, None) + cast_maybe_lowvram_patch([ pin ], dest, offload_stream) + return + if pin is None: + pin_offset = get_stream_pin_buffer_offset(size) + if pin_offset is not None: + stream_pin_queue.append((source, pin_offset, size, dest)) + return + cast_maybe_lowvram_patch(source, dest, offload_stream) + + handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size) for param_key in ("weight", "bias"): - lowvram_fn = getattr(s, param_key + "_lowvram_function", None) - if lowvram_fn is not None: + lowvram_source = getattr(s, param_key + "_lowvram_function", None) + if lowvram_source is not None: ensure_offload_stream(s, cast_buffer_offset, False) - lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + lowvram_size = lowvram_source.memory_required() + lowvram_dest = get_cast_buffer(lowvram_size) + lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True) + + pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches") + handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) + prefetch["xfer_dest"] = xfer_dest prefetch["cast_dest"] = cast_dest @@ -186,6 +231,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin prefetch["needs_cast"] = needs_cast s._prefetch = prefetch + if stream_pin_offset > 0: + if stream_pin_hostbuf.size < stream_pin_offset: + if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM): + for xfer_source, _, _, xfer_dest in stream_pin_queue: + cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream) + return offload_stream + stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf) + stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf + for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue: + pin = stream_pin_tensor[pin_offset:pin_offset + pin_size] + if isinstance(xfer_source, list): + comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest) + else: + cast_maybe_lowvram_patch(xfer_source, pin, None) + comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream) + stream_pin_hostbuf._comfy_event = offload_stream.record_event() + return offload_stream diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 6d3ba367a..0e8f573ba 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -2,42 +2,62 @@ import comfy.model_management import comfy.memory_management import comfy_aimdo.host_buffer import comfy_aimdo.torch +import torch from comfy.cli_args import args -def get_pin(module): - return getattr(module, "_pin", None) +def get_pin(module, subset="weights"): + pin = getattr(module, "_pin", None) + if pin is None or module._pin_registered or args.disable_pinned_memory: + return pin -def pin_memory(module): - if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + _, _, stack_split, pinned_size = module._pin_state[subset] + size = pin.nbytes + comfy.model_management.ensure_pin_registerable(size) + + if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0: + comfy.model_management.discard_cuda_async_error() + return pin + + module._pin_registered = True + stack_split[0] = max(stack_split[0], module._pin_stack_index) + comfy.model_management.TOTAL_PINNED_MEMORY += size + pinned_size[0] += size + return pin + +def pin_memory(module, subset="weights", size=None): + pin_state = module._pin_state + if args.disable_pinned_memory: return - size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) + pin = get_pin(module, subset) + if pin is not None or pin_state["failed"]: + return - if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: - module.pin_failed = True + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + if size is None: + size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) + offset = hostbuf.size + registerable_size = size + max(0, hostbuf.size - pinned_size[0]) + + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + if (not comfy.model_management.ensure_pin_budget(size) or + not comfy.model_management.ensure_pin_registerable(registerable_size)): + pin_state["failed"] = True return False try: - hostbuf = comfy_aimdo.host_buffer.HostBuffer(size) + hostbuf.extend(size=size) except RuntimeError: - module.pin_failed = True + pin_state["failed"] = True return False - module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf) - module._pin_hostbuf = hostbuf + module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] + module._pin.untyped_storage()._comfy_hostbuf = hostbuf + stack.append((module, offset)) + module._pin_registered = True + module._pin_stack_index = len(stack) - 1 + stack_split[0] = max(stack_split[0], module._pin_stack_index) comfy.model_management.TOTAL_PINNED_MEMORY += size + pinned_size[0] += size return True - -def unpin_memory(module): - if get_pin(module) is None: - return 0 - size = module._pin.numel() * module._pin.element_size() - - comfy.model_management.TOTAL_PINNED_MEMORY -= size - if comfy.model_management.TOTAL_PINNED_MEMORY < 0: - comfy.model_management.TOTAL_PINNED_MEMORY = 0 - - del module._pin - del module._pin_hostbuf - return size diff --git a/comfy/utils.py b/comfy/utils.py index 66682690a..00e382fac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -113,7 +113,6 @@ def load_safetensors(ckpt): "_comfy_tensor_file_slice", comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv)) - setattr(storage, "_comfy_tensor_mmap_touched", False) sd[name] = tensor return sd, header.get("__metadata__", {}), @@ -1451,4 +1450,3 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res - diff --git a/comfy/windows.py b/comfy/windows.py deleted file mode 100644 index 213dc481d..000000000 --- a/comfy/windows.py +++ /dev/null @@ -1,52 +0,0 @@ -import ctypes -import logging -import psutil -from ctypes import wintypes - -import comfy_aimdo.control - -psapi = ctypes.WinDLL("psapi") -kernel32 = ctypes.WinDLL("kernel32") - -class PERFORMANCE_INFORMATION(ctypes.Structure): - _fields_ = [ - ("cb", wintypes.DWORD), - ("CommitTotal", ctypes.c_size_t), - ("CommitLimit", ctypes.c_size_t), - ("CommitPeak", ctypes.c_size_t), - ("PhysicalTotal", ctypes.c_size_t), - ("PhysicalAvailable", ctypes.c_size_t), - ("SystemCache", ctypes.c_size_t), - ("KernelTotal", ctypes.c_size_t), - ("KernelPaged", ctypes.c_size_t), - ("KernelNonpaged", ctypes.c_size_t), - ("PageSize", ctypes.c_size_t), - ("HandleCount", wintypes.DWORD), - ("ProcessCount", wintypes.DWORD), - ("ThreadCount", wintypes.DWORD), - ] - -def get_free_ram(): - #Windows is way too conservative and chalks recently used uncommitted model RAM - #as "in-use". So, calculate free RAM for the sake of general use as the greater of: - # - #1: What psutil says - #2: Total Memory - (Committed Memory - VRAM in use) - # - #We have to subtract VRAM in use from the comitted memory as WDDM creates a naked - #commit charge for all VRAM used just incase it wants to page it all out. This just - #isn't realistic so "overcommit" on our calculations by just subtracting it off. - - pi = PERFORMANCE_INFORMATION() - pi.cb = ctypes.sizeof(pi) - - if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb): - logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal") - return psutil.virtual_memory().available - - committed = pi.CommitTotal * pi.PageSize - total = pi.PhysicalTotal * pi.PageSize - - return max(psutil.virtual_memory().available, - total - (committed - comfy_aimdo.control.get_total_vram_usage())) - diff --git a/execution.py b/execution.py index 4c7de2e84..5246d651c 100644 --- a/execution.py +++ b/execution.py @@ -2,6 +2,7 @@ import copy import heapq import inspect import logging +import psutil import sys import threading import time @@ -727,6 +728,7 @@ class PromptExecutor: self._notify_prompt_lifecycle("start", prompt_id) ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) + ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024 ** 3)) ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom) @@ -780,8 +782,14 @@ class PromptExecutor: execution_list.complete_node_execution() if self.cache_type == CacheType.RAM_PRESSURE: - comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom) - ram_release_callback(ram_headroom, free_active=True) + ram_release_callback(ram_inactive_headroom) + ram_shortfall = ram_headroom - psutil.virtual_memory().available + freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2)) + if freed < ram_shortfall: + if freed > 64 * (1024 ** 2): + # AIMDO MEM_DECOMMIT can outrun psutil.available catching up. + time.sleep(0.05) + ram_release_callback(ram_headroom, free_active=True) else: # Only execute when the while-loop ends without break # Send cached UI for intermediate output nodes that weren't executed diff --git a/main.py b/main.py index a6fdaf43c..1e47cab84 100644 --- a/main.py +++ b/main.py @@ -283,19 +283,25 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]: def prompt_worker(q, server_instance): current_time: float = 0.0 - cache_ram = args.cache_ram - if cache_ram < 0: + cache_ram = 0 + cache_ram_inactive = 0 + if not args.cache_classic and not args.cache_none and args.cache_lru <= 0: cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0)) + cache_ram_inactive = min(96.0, max(12.0, comfy.model_management.total_ram * 0.75 / 1024.0)) + if len(args.cache_ram) > 0: + cache_ram = args.cache_ram[0] + if len(args.cache_ram) > 1: + cache_ram_inactive = args.cache_ram[1] - cache_type = execution.CacheType.CLASSIC - if args.cache_lru > 0: + cache_type = execution.CacheType.RAM_PRESSURE + if args.cache_classic: + cache_type = execution.CacheType.CLASSIC + elif args.cache_lru > 0: cache_type = execution.CacheType.LRU - elif cache_ram > 0: - cache_type = execution.CacheType.RAM_PRESSURE elif args.cache_none: cache_type = execution.CacheType.NONE - e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } ) + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram, "ram_inactive" : cache_ram_inactive } ) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 diff --git a/requirements.txt b/requirements.txt index 1c87690da..d2986eda8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo==0.3.0 +comfy-aimdo==0.4.3 requests simpleeval>=1.0.0 blake3 diff --git a/tests/execution/test_async_nodes.py b/tests/execution/test_async_nodes.py index c771b4b36..54660c112 100644 --- a/tests/execution/test_async_nodes.py +++ b/tests/execution/test_async_nodes.py @@ -14,7 +14,6 @@ from tests.execution.test_execution import ComfyClient, run_warmup class TestAsyncNodes: @fixture(scope="class", autouse=True, params=[ (False, 0), - (True, 0), (True, 100), ]) def _server(self, args_pytest, request): @@ -29,6 +28,8 @@ class TestAsyncNodes: use_lru, lru_size = request.param if use_lru: pargs += ['--cache-lru', str(lru_size)] + else: + pargs += ['--cache-classic'] # Running server with args: pargs p = subprocess.Popen(pargs) yield diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index f73ca7e3c..15e2304fc 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -183,8 +183,7 @@ class TestExecution: # Initialize server and client # @fixture(scope="class", autouse=True, params=[ - { "extra_args" : [], "should_cache_results" : True }, - { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, + { "extra_args" : ["--cache-classic"], "should_cache_results" : True }, { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, { "extra_args" : ["--cache-none"], "should_cache_results" : False }, ]) From 95fdc6cf910f809e39edc3254470e619ffa9dbf8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 May 2026 17:17:55 -0700 Subject: [PATCH 2/3] Repo security stuff. (#14019) --- CODEOWNERS | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index 946dbf946..043c0ec75 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,2 +1,5 @@ -# Admins * @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai + +/CODEOWNERS @comfyanonymous +/.ci/ @comfyanonymous +/.github/ @comfyanonymous From 9f9b32ed978045262b71e6b27093e4ae80c29804 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Wed, 20 May 2026 21:22:12 -0700 Subject: [PATCH 3/3] feat: add OAuth 2.1 + RFC 7591 DCR endpoints to openapi.yaml (#14026) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the OAuth 2.1 authorization flow and RFC 7591 Dynamic Client Registration endpoints to the shared spec, alongside the existing auth-tagged operations (/api/auth/session, /api/auth/token, /.well-known/jwks.json). All tagged x-runtime: [cloud] with a [cloud-only] description prefix, following the established convention for cloud-runtime-only operations. Endpoints: - GET /.well-known/oauth-authorization-server (RFC 8414 metadata) - GET /.well-known/oauth-protected-resource (RFC 9728 metadata) - GET /oauth/authorize (consent challenge) - POST /oauth/authorize (consent submission) - POST /oauth/token (RFC 6749 §3.2) - POST /oauth/register (RFC 7591 §3.1 DCR) Component schemas added: - OAuthAuthorizationServerMetadata - OAuthProtectedResourceMetadata - OAuthConsentChallenge, OAuthConsentChallengeWorkspace - OAuthAuthorizeRedirectResponse - OAuthTokenResponse, OAuthTokenError - OAuthRegisterRequest, OAuthRegisterResponse, OAuthRegisterError These endpoints are implemented in the cloud runtime today and are called by browser frontends rendering the consent UI and by MCP-spec-compliant clients (Claude Desktop, Cursor, etc.) doing auto-discovery + self-registration. Documenting them in the shared spec lets the cloud frontend generate types directly from this spec instead of maintaining a parallel definition. Spectral lints clean (0 errors). The hint-level findings on OAuthTokenError / OAuthRegisterError ("standard error schema") match the same hint on CloudError — these are protocol-specific RFC-shaped errors, not generic application errors. --- openapi.yaml | 608 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 608 insertions(+) diff --git a/openapi.yaml b/openapi.yaml index 2658b9b86..92f7eaccc 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -3790,6 +3790,295 @@ paths: schema: $ref: "#/components/schemas/JwksResponse" + # --------------------------------------------------------------------------- + # OAuth 2.1 / RFC 7591 Dynamic Client Registration (cloud) + # --------------------------------------------------------------------------- + /.well-known/oauth-authorization-server: + get: + operationId: getOAuthAuthorizationServer + tags: [auth] + summary: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)" + description: "[cloud-only] Public metadata document for OAuth 2.1 clients. Cached 5 minutes." + x-runtime: [cloud] + security: [] + responses: + "200": + description: Authorization-server metadata + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthAuthorizationServerMetadata" + "404": + description: OAuth disabled + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /.well-known/oauth-protected-resource: + get: + operationId: getOAuthProtectedResource + tags: [auth] + summary: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)" + description: "[cloud-only] Public metadata describing the currently advertised protected resource. Cached 5 minutes." + x-runtime: [cloud] + security: [] + responses: + "200": + description: Protected-resource metadata + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthProtectedResourceMetadata" + "404": + description: OAuth disabled or no active resource configured + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /oauth/authorize: + get: + operationId: getOAuthAuthorize + tags: [auth] + summary: "[cloud-only] Begin or resume an OAuth 2.1 authorization request" + description: | + [cloud-only] Two modes: + - **Initial entry** (OAuth params present): validates client/redirect/resource/scopes, persists a server-side authorization-request row, and either redirects (no session / unverified email) to the configured frontend login URL carrying only the opaque `oauth_request_id`, or returns the JSON consent challenge for the frontend to render. + - **Resume** (`oauth_request_id` present): loads the server-side row, fails closed if expired/consumed/unknown, returns the JSON consent challenge. Browser-replayed OAuth params are intentionally ignored. + + The frontend renders the consent UI from the JSON payload and POSTs the user's decision back to this endpoint. + x-runtime: [cloud] + security: [] + parameters: + - { name: response_type, in: query, required: false, schema: { type: string } } + - { name: client_id, in: query, required: false, schema: { type: string } } + - { name: redirect_uri, in: query, required: false, schema: { type: string } } + - { name: scope, in: query, required: false, schema: { type: string } } + - name: state + in: query + required: false + schema: { type: string } + description: | + RFC 6749 §10.12 marks `state` as RECOMMENDED. Cloud hardening makes it REQUIRED on the initial-entry path (omitted only on the resume path where `oauth_request_id` is supplied instead). This parameter is `required: false` at the spec level only because the operation is dual-mode (initial entry vs. resume); the runtime rejects empty `state` on the initial-entry path with a stable `invalid_request` 400. + - { name: code_challenge, in: query, required: false, schema: { type: string } } + - { name: code_challenge_method, in: query, required: false, schema: { type: string } } + - { name: resource, in: query, required: false, schema: { type: string } } + - { name: oauth_request_id, in: query, required: false, schema: { type: string } } + responses: + "200": + description: Consent challenge payload (session present, email verified). Frontend renders the consent UI from this payload and POSTs back to /oauth/authorize. + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthConsentChallenge" + "302": + description: Redirect to login (no session / unverified email) or to registered redirect_uri (pre-validated client error) + headers: + Location: + schema: + type: string + "400": + description: Invalid authorize request (pre-redirect failure — unknown client, redirect mismatch, malformed params) + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: OAuth disabled + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: postOAuthAuthorize + tags: [auth] + summary: "[cloud-only] Submit OAuth consent decision" + description: | + [cloud-only] JSON-only consent submission. The handler verifies the per-row CSRF token, atomically marks the authorization request consumed (single-use covers both allow and deny paths), then returns the redirect URL the browser must navigate to. The URL contains either `code` + original `state` for allow, or the RFC 6749 §5.2 error and `state` for deny. + + Workspace membership is re-checked at submission time. Consent is persisted keyed by `(user_id, client_id, resource_id, workspace_id)`; broadening the previously approved scope set requires a fresh consent flow. + x-runtime: [cloud] + security: [] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [oauth_request_id, csrf_token, decision, workspace_id] + properties: + oauth_request_id: { type: string, format: uuid } + csrf_token: { type: string } + decision: { type: string, enum: [allow, deny] } + workspace_id: { type: string } + responses: + "200": + description: Redirect URL for the frontend to navigate to (allow → with code+state; deny → with error+state) + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthAuthorizeRedirectResponse" + "400": + description: Bad request (CSRF mismatch, expired/consumed request, inaccessible workspace) + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Scope broadening on consent re-grant — fresh consent flow required + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: OAuth disabled + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /oauth/token: + post: + operationId: postOAuthToken + tags: [auth] + summary: "[cloud-only] Exchange authorization code or refresh token for a resource-bound access token" + description: | + [cloud-only] OAuth 2.1 token endpoint (RFC 6749 §3.2). Public clients only — `client_secret` is rejected. + + Two grant types are supported: + - `authorization_code` — exchanges the code minted by `/oauth/authorize` (with PKCE verifier) for an access token + first refresh token. Single-use; reuse fails closed. + - `refresh_token` — rotates the refresh token. Old token immediately invalid; presenting an already-rotated token revokes the entire token family and emits a security metric. + + Both grant types re-validate canonical user state, current workspace membership, and the resource's active flag at every mint. A code or refresh token bound to a deactivated resource fails closed. + + Errors follow RFC 6749 §5.2. Logs never contain raw codes, refresh tokens, or minted tokens. + + Per RFC 6749 §5.1, every 200 and 400 response carries `Cache-Control: no-store` and `Pragma: no-cache` so intermediaries cannot cache token-bearing or state-change-reason responses. + x-runtime: [cloud] + security: [] + requestBody: + required: true + content: + application/x-www-form-urlencoded: + schema: + type: object + required: [grant_type, client_id] + properties: + grant_type: { type: string, enum: [authorization_code, refresh_token] } + client_id: { type: string } + code: { type: string } + redirect_uri: { type: string } + code_verifier: { type: string } + refresh_token: { type: string } + scope: { type: string } + client_secret: { type: string } + responses: + "200": + description: New token pair + headers: + Cache-Control: + schema: + type: string + description: 'Always "no-store" per RFC 6749 §5.1' + Pragma: + schema: + type: string + description: 'Always "no-cache" per RFC 6749 §5.1' + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthTokenResponse" + "400": + description: RFC 6749 §5.2 error + headers: + Cache-Control: + schema: + type: string + description: 'Always "no-store" per RFC 6749 §5.1' + Pragma: + schema: + type: string + description: 'Always "no-cache" per RFC 6749 §5.1' + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthTokenError" + "404": + description: OAuth disabled + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /oauth/register: + post: + operationId: postOAuthRegister + tags: [auth] + summary: "[cloud-only] Dynamic Client Registration (RFC 7591)" + description: | + [cloud-only] Public, unauthenticated, insert-only RFC 7591 §3.1 client registration. Used by MCP-spec-compliant clients to self-register a public OAuth client without operator involvement. + + Policy: + + - Public clients only — `token_endpoint_auth_method` is forced to `none`. Confidential-client registration is out of scope this phase. + - Server-owned `resource_grants`. Caller-supplied `scope` or `resource_grants` is rejected as `invalid_client_metadata` (would be a privilege-escalation surface). Dynamic clients receive the same scopes the active resource publishes. + - Application-type-aware redirect URI policy. `application_type=native` accepts loopback (`127.0.0.1`, `::1`, `localhost`) and reverse-DNS-shaped custom schemes; `application_type=web` accepts HTTPS to hosts in an operator-controlled allowlist only. `application_type` is REQUIRED on the request — missing or empty rejects with `invalid_client_metadata`. + - Anti-impersonation: reserved client names are rejected from third parties via NFKC-folded compare. + - Generated `client_id` carries a stable prefix to distinguish dynamic from seeded clients in audit logs. + - Cache-Control: `no-store` on every 201 and 400 response (the response carries fresh credentials and rejection reasons). + x-runtime: [cloud] + security: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthRegisterRequest" + responses: + "201": + description: Registered. Body echoes the metadata RFC 7591 §3.2.1 requires. + headers: + Cache-Control: + schema: + type: string + description: 'Always "no-store"' + Pragma: + schema: + type: string + description: 'Always "no-cache"' + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthRegisterResponse" + "400": + description: RFC 7591 §3.2.2 invalid client metadata + headers: + Cache-Control: + schema: + type: string + description: 'Always "no-store"' + Pragma: + schema: + type: string + description: 'Always "no-cache"' + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthRegisterError" + "404": + description: OAuth disabled + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "503": + description: No active resource is configured — DCR cannot mint a usable client until an active resource row is seeded. + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + # --------------------------------------------------------------------------- # Billing (cloud) # --------------------------------------------------------------------------- @@ -7424,6 +7713,325 @@ components: description: RSA exponent (base64url) additionalProperties: true + OAuthAuthorizationServerMetadata: + type: object + x-runtime: [cloud] + description: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)." + required: + - issuer + - authorization_endpoint + - token_endpoint + - jwks_uri + - response_types_supported + - grant_types_supported + - code_challenge_methods_supported + - token_endpoint_auth_methods_supported + properties: + issuer: + type: string + format: uri + authorization_endpoint: + type: string + format: uri + token_endpoint: + type: string + format: uri + jwks_uri: + type: string + format: uri + registration_endpoint: + type: string + format: uri + description: "[cloud-only] RFC 7591 §3.1 Dynamic Client Registration endpoint. Advertised so MCP-spec-compliant clients can auto-discover and self-register without operator involvement. Present only when DCR is enabled." + response_types_supported: + type: array + items: + type: string + grant_types_supported: + type: array + items: + type: string + code_challenge_methods_supported: + type: array + items: + type: string + token_endpoint_auth_methods_supported: + type: array + items: + type: string + scopes_supported: + type: array + items: + type: string + + OAuthProtectedResourceMetadata: + type: object + x-runtime: [cloud] + description: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)." + required: + - resource + - authorization_servers + - scopes_supported + properties: + resource: + type: string + format: uri + authorization_servers: + type: array + items: + type: string + format: uri + scopes_supported: + type: array + items: + type: string + bearer_methods_supported: + type: array + items: + type: string + + OAuthConsentChallenge: + type: object + x-runtime: [cloud] + description: "[cloud-only] Server-side state describing the OAuth consent decision the user is being asked to make. Returned by GET /oauth/authorize when a valid session exists; the frontend renders the consent UI from this payload and POSTs the decision back. Browser never sees the original OAuth params on resume." + required: + - oauth_request_id + - csrf_token + - client_display_name + - resource_display_name + - scopes + - workspaces + properties: + oauth_request_id: + type: string + format: uuid + description: Opaque server-side identifier for the authorization-request row. Carried back unchanged in the consent submission. + csrf_token: + type: string + description: Per-row CSRF token bound to this authorization request (not to the session). Must be echoed back on POST. + client_display_name: + type: string + description: Human-readable name of the OAuth client requesting authorization. + resource_display_name: + type: string + description: Human-readable name of the protected resource. + scopes: + type: array + description: Scopes the client is requesting for this resource. The frontend should present these for the user to approve. + items: + type: string + workspaces: + type: array + description: Workspaces the user can select from. Membership is re-checked on POST. + items: + $ref: "#/components/schemas/OAuthConsentChallengeWorkspace" + + OAuthConsentChallengeWorkspace: + type: object + x-runtime: [cloud] + description: "[cloud-only] One workspace option presented in the OAuth consent challenge." + required: [id, name, type, role] + properties: + id: { type: string } + name: { type: string } + type: { type: string, enum: [personal, team] } + role: { type: string, enum: [owner, member] } + + OAuthAuthorizeRedirectResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] Redirect target produced after a JSON consent submission. The frontend must navigate the browser to this URL so custom-scheme client callbacks work without relying on fetch-visible 302 headers." + required: + - redirect_url + properties: + redirect_url: + type: string + format: uri + description: OAuth client redirect URI with either code+state for allow, or error+state for deny. + + OAuthTokenResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] RFC 6749 §5.1 successful token response." + required: [access_token, token_type, expires_in, refresh_token, scope] + properties: + access_token: + type: string + description: Resource-bound access token (audience matches the protected resource). + token_type: + type: string + enum: [Bearer] + expires_in: + type: integer + description: Access token lifetime in seconds. + refresh_token: + type: string + description: Opaque refresh token. Rotates on every successful refresh; presenting an already-rotated token revokes the entire family. + scope: + type: string + description: Space-delimited scopes granted with this token. + + OAuthTokenError: + type: object + x-runtime: [cloud] + description: "[cloud-only] RFC 6749 §5.2 error response." + required: [error] + properties: + error: + type: string + description: 'RFC 6749 §5.2 error code: invalid_request, invalid_client, invalid_grant, unauthorized_client, unsupported_grant_type, invalid_scope.' + error_description: + type: string + description: Human-readable, no leak of internal storage state. + + OAuthRegisterRequest: + type: object + x-runtime: [cloud] + additionalProperties: false + description: "[cloud-only] RFC 7591 §2 client metadata document. Only the fields the server honors are listed; presence of `scope` or `resource_grants` in the request is rejected (`invalid_client_metadata`) because those are server-owned for dynamic clients." + required: + - redirect_uris + - application_type + properties: + redirect_uris: + type: array + items: + type: string + minItems: 1 + maxItems: 5 + description: 1–5 redirect URIs. Validated against `application_type` policy. + client_name: + type: string + maxLength: 100 + description: Human-readable name shown in the consent UI. Reserved-name list rejects impersonation of major clients. + application_type: + type: string + enum: [native, web] + description: | + RFC 7591 §2 application_type. **REQUIRED** — clients MUST declare intent; the server does not default this field. `native` for desktop / CLI / MCP-spec-strict clients (loopback redirects); `web` for hosted clients (HTTPS only, host must be allowlisted). A missing or explicitly empty `application_type` rejects with `invalid_client_metadata`. + token_endpoint_auth_method: + type: string + enum: [none] + description: 'Public clients only this phase — must be `none` if present. The server forces `none` regardless.' + grant_types: + type: array + items: + type: string + enum: [authorization_code, refresh_token] + description: Optional. Defaults to `["authorization_code","refresh_token"]`. + response_types: + type: array + items: + type: string + enum: [code] + description: Optional. Defaults to `["code"]`. + scope: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Dynamic clients do not pick scopes — the server assigns scopes from the active resource's published list. Sending `scope` in the registration body is treated as a privilege-escalation attempt and returns `invalid_client_metadata`." + resource_grants: + type: object + nullable: true + additionalProperties: + type: array + items: + type: string + description: "**REJECTED IF PRESENT.** Same reason as `scope`. The set of resources and scopes a dynamic client may request is server-policy, not request-driven." + client_uri: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + logo_uri: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + tos_uri: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + policy_uri: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + software_id: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + software_version: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + contacts: + type: array + nullable: true + items: + type: string + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + jwks: + type: object + nullable: true + additionalProperties: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + jwks_uri: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + + OAuthRegisterResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] RFC 7591 §3.2.1 successful registration response." + required: + - client_id + - client_id_issued_at + - redirect_uris + - grant_types + - response_types + - token_endpoint_auth_method + - application_type + properties: + client_id: + type: string + description: Server-generated client_id. + client_id_issued_at: + type: integer + format: int64 + description: Unix timestamp (seconds) when the client was registered. + client_name: + type: string + redirect_uris: + type: array + items: + type: string + grant_types: + type: array + items: + type: string + response_types: + type: array + items: + type: string + token_endpoint_auth_method: + type: string + enum: [none] + application_type: + type: string + enum: [native, web] + + OAuthRegisterError: + type: object + x-runtime: [cloud] + description: "[cloud-only] RFC 7591 §3.2.2 error response." + required: + - error + properties: + error: + type: string + enum: [invalid_redirect_uri, invalid_client_metadata] + error_description: + type: string + nullable: true + BillingBalance: type: object x-runtime: [cloud]