From 663aaf8b969ae2018933e7f689fe6b9fe423e296 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Thu, 7 May 2026 19:55:41 +1000
Subject: [PATCH] Implement JIT pinned memory pressure

Replace the predictive pin-pressure mechanism with just-in-time (JIT)
pinned-memory pressure. Instead of rejecting a pin request up front when it
would push TOTAL_PINNED_MEMORY past MAX_PINNED_MEMORY, ensure_pin_budget()
now frees pinned host RAM on demand by partially unloading dynamic models,
freeing PIN_PRESSURE_HYSTERESIS extra bytes so the next request does not
immediately hit the limit again.
---
 comfy/model_management.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++-----
 comfy/model_patcher.py    | 17 ++++++++++++++---
 comfy/ops.py              |  8 ++++--
 comfy/pinned_memory.py    |  4 +--
 4 files changed, 65 insertions(+), 14 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 4b96d1492..6a2126cb5 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -498,6 +498,8 @@ current_loaded_models = []
 
 DIRTY_MMAPS = set()
 
+PIN_PRESSURE_HYSTERESIS = 128 * 1024 * 1024
+
 def module_size(module):
     module_mem = 0
     sd = module.state_dict()
@@ -511,6 +513,21 @@ def mark_mmap_dirty(storage):
     if mmap_refs is not None:
         DIRTY_MMAPS.add(mmap_refs[0])
 
+def ensure_pin_budget(size, evict_active=False):
+    if MAX_PINNED_MEMORY <= 0:
+        return
+
+    shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
+    if shortfall <= 0:
+        return
+
+    shortfall += PIN_PRESSURE_HYSTERESIS
+    for loaded_model in reversed(current_loaded_models):
+        model = loaded_model.model
+        if model is not None and model.is_dynamic() and (evict_active or not model.dynamic_pins[model.load_device]["active"]):
+            shortfall -= model.partially_unload_ram(shortfall)
+            if shortfall <= 0:
+                break
 
 class LoadedModel:
     def __init__(self, model):
@@ -1133,7 +1150,6 @@ LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
 
 STREAM_PIN_BUFFERS = {}
 DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
-DEFAULT_PIN_BUFFER_PRIME_SIZE = 1024 ** 2
 
 def get_cast_buffer(offload_stream, device, size, ref):
     global LARGEST_CASTED_WEIGHT
@@ -1177,14 +1193,29 @@ def get_aimdo_cast_buffer(offload_stream, device):
 def get_pin_buffer(offload_stream):
     pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
     if pin_buffer is None:
-        # A small non-zero default primes HostBuffer's larger virtual reservation.
-        pin_buffer = comfy_aimdo.host_buffer.HostBuffer(DEFAULT_PIN_BUFFER_PRIME_SIZE)
+        pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0)
         STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
     elif offload_stream is not None:
         offload_stream.synchronize()
 
     return pin_buffer
 
+def resize_pin_buffer(pin_buffer, size):
+    global TOTAL_PINNED_MEMORY
+    old_size = getattr(pin_buffer, "_comfy_stream_pin_size", 0)
+    if size <= old_size:
+        return True
+    growth = size - old_size
+    ensure_pin_budget(growth, evict_active=True)
+    try:
+        pin_buffer.extend(size=size, reallocate=True)
+    except RuntimeError:
+        return False
+    pin_buffer._comfy_stream_pin_size = size
+    TOTAL_PINNED_MEMORY += growth
+    return True
+
 def reset_cast_buffers():
+    global TOTAL_PINNED_MEMORY
     global LARGEST_CASTED_WEIGHT
     global LARGEST_AIMDO_CASTED_WEIGHT
@@ -1197,8 +1228,18 @@ def reset_cast_buffers():
 
     for mmap_obj in DIRTY_MMAPS:
         mmap_obj.bounce()
     DIRTY_MMAPS.clear()
+
+    for pin_buffer in STREAM_PIN_BUFFERS.values():
+        TOTAL_PINNED_MEMORY -= getattr(pin_buffer, "_comfy_stream_pin_size", 0)
+    if TOTAL_PINNED_MEMORY < 0:
+        TOTAL_PINNED_MEMORY = 0
+
+    for loaded_model in current_loaded_models:
+        model = loaded_model.model
+        if model is not None and model.is_dynamic():
+            model.dynamic_pins[model.load_device]["active"] = False
 
     STREAM_CAST_BUFFERS.clear()
     STREAM_AIMDO_CAST_BUFFERS.clear()
     STREAM_PIN_BUFFERS.clear()
@@ -1344,8 +1385,7 @@ def pin_memory(tensor):
         return False
 
     size = tensor.nbytes
-    if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
-        return False
+    ensure_pin_budget(size)
 
     ptr = tensor.data_ptr()
     if ptr == 0:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index cfb540eb2..0b590138d 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1085,7 +1085,7 @@
         return 0
 
     def partially_unload_ram(self, ram_to_unload):
-        pass
+        return 0
 
     def detach(self, unpatch_all=True):
         self.eject_model()
@@ -1495,7 +1495,12 @@ class ModelPatcherDynamic(ModelPatcher):
         if not hasattr(self.model, "dynamic_pins"):
             self.model.dynamic_pins = {}
         if self.load_device not in self.model.dynamic_pins:
-            self.model.dynamic_pins[self.load_device] = {"hostbuf": comfy_aimdo.host_buffer.HostBuffer(0), "stack": [], "failed": False}
+            self.model.dynamic_pins[self.load_device] = {
+                "hostbuf": comfy_aimdo.host_buffer.HostBuffer(0),
+                "stack": [],
+                "failed": False,
+                "active": False,
+            }
 
         self.non_dynamic_delegate_model = None
         assert load_device is not None
@@ -1559,6 +1564,7 @@
         vbar = self._vbar_get(create=True)
         pin_state = self.model.dynamic_pins[self.load_device]
         pin_state["failed"] = False
+        pin_state["active"] = True
 
         if vbar is not None:
             vbar.prioritize()
@@ -1688,9 +1694,10 @@
         return self.model.dynamic_pins[self.load_device]["hostbuf"].size
 
     def partially_unload_ram(self, ram_to_unload):
+        freed = 0
         pin_state = self.model.dynamic_pins[self.load_device]
         hostbuf = pin_state["hostbuf"]
-        stack = self.model.dynamic_pins[self.load_device]["stack"]
+        stack = pin_state["stack"]
         while len(stack) > 0:
             module, offset = stack.pop()
             size = module._pin.numel() * module._pin.element_size()
@@ -1699,9 +1706,11 @@
             comfy.model_management.TOTAL_PINNED_MEMORY -= size
             if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
                 comfy.model_management.TOTAL_PINNED_MEMORY = 0
+            freed += size
             ram_to_unload -= size
             if ram_to_unload <= 0:
-                return
+                return freed
+        return freed
 
     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
         #This isn't used by the core at all and can only be to load a model out of
diff --git a/comfy/ops.py b/comfy/ops.py
index 2f9be9285..e060d9ba9 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -138,6 +138,8 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
         if stream_pin_hostbuf is None:
             stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
+            if stream_pin_hostbuf is None:
+                return None
 
         offset = stream_pin_offset
         stream_pin_offset += buffer_size
 
@@ -213,8 +215,10 @@
     stream_pin_hostbuf_size = getattr(stream_pin_hostbuf, "_comfy_stream_pin_size", stream_pin_hostbuf.size)
     if stream_pin_hostbuf_size < stream_pin_offset:
         stream_pin_hostbuf_size = stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM
-        stream_pin_hostbuf.extend(size=stream_pin_hostbuf_size, reallocate=True)
-        stream_pin_hostbuf._comfy_stream_pin_size = stream_pin_hostbuf_size
+        if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_hostbuf_size):
+            for xfer_source, _, _, xfer_dest in stream_pin_queue:
+                comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+            return offload_stream
     stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
     stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
     for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py
index 3638066c8..a35759aad 100644
--- a/comfy/pinned_memory.py
+++ b/comfy/pinned_memory.py
@@ -16,9 +16,7 @@ def pin_memory(module):
     hostbuf = pin_state["hostbuf"]
     size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
     offset = hostbuf.size
-    if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
-        pin_state["failed"] = True
-        return False
+    comfy.model_management.ensure_pin_budget(size)
 
     try:
         hostbuf.extend(size=size)
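
For review, here is a minimal, self-contained sketch of the JIT eviction flow
this patch introduces. The Model class, its pinned/active fields, and the demo
values are illustrative assumptions standing in for ModelPatcherDynamic and its
dynamic_pins state; only the budget arithmetic mirrors ensure_pin_budget() above.

MAX_PINNED_MEMORY = 4 * 1024 ** 3            # pin budget; <= 0 means unlimited
PIN_PRESSURE_HYSTERESIS = 128 * 1024 * 1024  # free extra to avoid re-triggering
TOTAL_PINNED_MEMORY = 0

class Model:
    """Hypothetical stand-in for a dynamic model holding pinned host RAM."""
    def __init__(self, name, pinned):
        self.name = name
        self.pinned = pinned   # bytes of pinned host memory currently held
        self.active = False    # True while its pin buffer is in active use

    def partially_unload_ram(self, ram_to_unload):
        # Free up to ram_to_unload bytes and report how much was freed,
        # the contract partially_unload_ram() now has (the base returns 0).
        global TOTAL_PINNED_MEMORY
        freed = min(self.pinned, ram_to_unload)
        self.pinned -= freed
        TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY - freed)
        return freed

current_loaded_models = []  # most recently loaded first, as in model_management

def ensure_pin_budget(size, evict_active=False):
    # No-op when the budget is unlimited or the request already fits.
    if MAX_PINNED_MEMORY <= 0:
        return
    shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
    if shortfall <= 0:
        return
    # Over budget: free the shortfall plus hysteresis, walking the least
    # recently loaded models first, skipping active buffers unless forced.
    shortfall += PIN_PRESSURE_HYSTERESIS
    for model in reversed(current_loaded_models):
        if evict_active or not model.active:
            shortfall -= model.partially_unload_ram(shortfall)
            if shortfall <= 0:
                break

if __name__ == "__main__":
    current_loaded_models[:] = [Model("new", 2 * 1024 ** 3),
                                Model("old", 2 * 1024 ** 3)]
    TOTAL_PINNED_MEMORY = 4 * 1024 ** 3
    ensure_pin_budget(1024 ** 3)           # 1 GiB request against a full budget
    print([(m.name, m.pinned) for m in current_loaded_models])
    # -> [('new', 2147483648), ('old', 939524096)]: "old" was evicted first,
    #    including the extra 128 MiB hysteresis margin.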