From 9712cdf305ca1334fedfac03ec67442a02bb8f02 Mon Sep 17 00:00:00 2001 From: Rattus Date: Mon, 4 May 2026 12:32:12 +1000 Subject: [PATCH] pinned_memory: implement with aimdo growable buffer Use a single growable buffer so we can do threaded pre-warming on pinned memory. --- comfy/model_patcher.py | 34 +++++++++++++++++++++------------- comfy/pinned_memory.py | 30 ++++++++++-------------------- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 7d2d6883f..0e3e002f3 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -34,6 +34,7 @@ import comfy.model_management import comfy.ops import comfy.patcher_extension import comfy.utils +import comfy_aimdo.host_buffer from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -1494,6 +1495,10 @@ class ModelPatcherDynamic(ModelPatcher): super().__init__(model, load_device, offload_device, size, weight_inplace_update) if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} + if not hasattr(self.model, "dynamic_pins"): + self.model.dynamic_pins = {} + if self.load_device not in self.model.dynamic_pins: + self.model.dynamic_pins[self.load_device] = {"hostbuf": comfy_aimdo.host_buffer.HostBuffer(0), "stack": [], "failed": False} self.non_dynamic_delegate_model = None assert load_device is not None @@ -1555,6 +1560,8 @@ class ModelPatcherDynamic(ModelPatcher): self.unpatch_hooks() vbar = self._vbar_get(create=True) + pin_state = self.model.dynamic_pins[self.load_device] + pin_state["failed"] = False if vbar is not None: vbar.prioritize() @@ -1606,8 +1613,8 @@ class ModelPatcherDynamic(ModelPatcher): if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True - m.pin_failed = False m.seed_key = n + m._pin_state = pin_state set_dirty(m, dirty) force_load, v_weight_size = setup_param(self, m, n, "weight") @@ -1681,20 +1688,21 @@ class ModelPatcherDynamic(ModelPatcher): return freed def pinned_memory_size(self): - total = 0 - loading = self._load_list(for_dynamic=True) - for x in loading: - _, _, _, _, m, _ = x - pin = comfy.pinned_memory.get_pin(m) - if pin is not None: - total += pin.numel() * pin.element_size() - return total + return self.model.dynamic_pins[self.load_device]["hostbuf"].size def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(for_dynamic=True, default_device=self.offload_device) - for x in loading: - *_, m, _ = x - ram_to_unload -= comfy.pinned_memory.unpin_memory(m) + pin_state = self.model.dynamic_pins[self.load_device] + hostbuf = pin_state["hostbuf"] + stack = self.model.dynamic_pins[self.load_device]["stack"] + while len(stack) > 0: + module, offset = stack.pop() + size = module._pin.numel() * module._pin.element_size() + del module._pin + hostbuf.truncate(offset) + comfy.model_management.TOTAL_PINNED_MEMORY -= size + if comfy.model_management.TOTAL_PINNED_MEMORY < 0: + comfy.model_management.TOTAL_PINNED_MEMORY = 0 + ram_to_unload -= size if ram_to_unload <= 0: return diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 6d3ba367a..3638066c8 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -9,35 +9,25 @@ def get_pin(module): return getattr(module, "_pin", None) def pin_memory(module): - if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + pin_state = module._pin_state + if pin_state["failed"] or args.disable_pinned_memory or get_pin(module) is not None: return + hostbuf = pin_state["hostbuf"] size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) - + offset = hostbuf.size if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: - module.pin_failed = True + pin_state["failed"] = True return False try: - hostbuf = comfy_aimdo.host_buffer.HostBuffer(size) + hostbuf.extend(size=size) except RuntimeError: - module.pin_failed = True + pin_state["failed"] = True return False - module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf) - module._pin_hostbuf = hostbuf + module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] + module._pin.untyped_storage()._comfy_hostbuf = hostbuf + pin_state["stack"].append((module, offset)) comfy.model_management.TOTAL_PINNED_MEMORY += size return True - -def unpin_memory(module): - if get_pin(module) is None: - return 0 - size = module._pin.numel() * module._pin.element_size() - - comfy.model_management.TOTAL_PINNED_MEMORY -= size - if comfy.model_management.TOTAL_PINNED_MEMORY < 0: - comfy.model_management.TOTAL_PINNED_MEMORY = 0 - - del module._pin - del module._pin_hostbuf - return size