pinned_memory: implement with aimdo growable buffer

Use a single growable buffer so we can do threaded pre-warming of
pinned memory.
This commit is contained in:
Rattus 2026-05-04 12:32:12 +10:00
parent e5f84d0b64
commit 9712cdf305
2 changed files with 31 additions and 33 deletions

View File

@ -34,6 +34,7 @@ import comfy.model_management
import comfy.ops
import comfy.patcher_extension
import comfy.utils
import comfy_aimdo.host_buffer
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -1494,6 +1495,10 @@ class ModelPatcherDynamic(ModelPatcher):
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
if not hasattr(self.model, "dynamic_pins"):
self.model.dynamic_pins = {}
if self.load_device not in self.model.dynamic_pins:
self.model.dynamic_pins[self.load_device] = {"hostbuf": comfy_aimdo.host_buffer.HostBuffer(0), "stack": [], "failed": False}
self.non_dynamic_delegate_model = None
assert load_device is not None
@ -1555,6 +1560,8 @@ class ModelPatcherDynamic(ModelPatcher):
self.unpatch_hooks()
vbar = self._vbar_get(create=True)
pin_state = self.model.dynamic_pins[self.load_device]
pin_state["failed"] = False
if vbar is not None:
vbar.prioritize()
@ -1606,8 +1613,8 @@ class ModelPatcherDynamic(ModelPatcher):
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
m.pin_failed = False
m.seed_key = n
m._pin_state = pin_state
set_dirty(m, dirty)
force_load, v_weight_size = setup_param(self, m, n, "weight")
@ -1681,20 +1688,21 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
def pinned_memory_size(self):
    """Return the total bytes of pinned host memory held for this model's
    load device.

    All per-module pins are carved out of one shared growable host buffer
    (``dynamic_pins[load_device]["hostbuf"]``), so the buffer's current size
    is the whole pinned footprint — no per-module walk is needed.
    """
    return self.model.dynamic_pins[self.load_device]["hostbuf"].size
def partially_unload_ram(self, ram_to_unload):
    """Release pinned host memory until at least ``ram_to_unload`` bytes are
    freed or no pins remain.

    Pins are unwound LIFO: each ``stack`` entry is ``(module, offset)`` into
    the shared host buffer, so after dropping the module's pin view the
    buffer can be truncated back to ``offset`` to return that region.
    """
    pin_state = self.model.dynamic_pins[self.load_device]
    hostbuf = pin_state["hostbuf"]
    stack = pin_state["stack"]
    while stack:
        module, offset = stack.pop()
        size = module._pin.numel() * module._pin.element_size()
        # Drop the tensor view before shrinking the buffer that backs it.
        del module._pin
        hostbuf.truncate(offset)
        comfy.model_management.TOTAL_PINNED_MEMORY -= size
        # Clamp the global accounting so drift never goes negative.
        if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
            comfy.model_management.TOTAL_PINNED_MEMORY = 0
        ram_to_unload -= size
        if ram_to_unload <= 0:
            return

View File

@ -9,35 +9,25 @@ def get_pin(module):
return getattr(module, "_pin", None)
def pin_memory(module):
    """Pin the host memory backing ``module``'s weights by carving a slice
    out of the shared growable host buffer for its load device.

    Returns True on success, False when pinning is refused (budget exceeded
    or allocation failure), and None when there is nothing to do (already
    pinned, pinning disabled, or a previous failure on this device).

    On success, records ``(module, offset)`` on the device's pin stack so
    ``partially_unload_ram`` can unwind pins LIFO by truncating the buffer.
    """
    pin_state = module._pin_state
    # A single failure poisons the whole device's pin state, not one module.
    if pin_state["failed"] or args.disable_pinned_memory or get_pin(module) is not None:
        return
    hostbuf = pin_state["hostbuf"]
    size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
    # The new slice starts at the current end of the shared buffer.
    offset = hostbuf.size
    if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
        pin_state["failed"] = True
        return False
    try:
        hostbuf.extend(size=size)
    except RuntimeError:
        pin_state["failed"] = True
        return False
    # View into the shared buffer; keep the buffer alive via the storage.
    module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
    module._pin.untyped_storage()._comfy_hostbuf = hostbuf
    pin_state["stack"].append((module, offset))
    comfy.model_management.TOTAL_PINNED_MEMORY += size
    return True
# NOTE(review): legacy per-module unpin path. It reads module._pin_hostbuf,
# which the stack-based pin_memory above no longer sets, so this appears to
# be the pre-hostbuf implementation — confirm whether any caller still
# reaches it before relying on it.
def unpin_memory(module):
# Nothing pinned for this module: report zero bytes freed.
if get_pin(module) is None:
return 0
# Bytes held by this module's pin tensor.
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.TOTAL_PINNED_MEMORY -= size
# Clamp global accounting so drift never goes negative.
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
del module._pin
del module._pin_hostbuf
return size