prepare for multiple pin sets

This commit is contained in:
Rattus 2026-05-08 09:47:08 +10:00
parent 01eba77dba
commit eabf34c453
3 changed files with 28 additions and 24 deletions

View File

@@ -1238,6 +1238,8 @@ def reset_cast_buffers():
model = loaded_model.model
if model is not None and model.is_dynamic():
model.dynamic_pins[model.load_device]["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0), [])
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()

View File

@@ -1508,8 +1508,8 @@ class ModelPatcherDynamic(ModelPatcher):
self.model.dynamic_pins = {}
if self.load_device not in self.model.dynamic_pins:
self.model.dynamic_pins[self.load_device] = {
"hostbuf": comfy_aimdo.host_buffer.HostBuffer(0),
"stack": [],
"weights": (comfy_aimdo.host_buffer.HostBuffer(0), []),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0), []),
"failed": False,
"active": False,
}
@@ -1703,25 +1703,26 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
def pinned_memory_size(self):
return self.model.dynamic_pins[self.load_device]["hostbuf"].size
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
self.model.dynamic_pins[self.load_device]["patches"][0].size)
def partially_unload_ram(self, ram_to_unload):
def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
hostbuf = pin_state["hostbuf"]
stack = pin_state["stack"]
while len(stack) > 0:
module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size()
del module._pin
hostbuf.truncate(offset)
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:
return freed
for subset in subsets:
hostbuf, stack = pin_state[subset]
while len(stack) > 0:
module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size()
del module._pin
hostbuf.truncate(offset)
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:
return freed
return freed
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):

View File

@@ -5,16 +5,17 @@ import comfy_aimdo.torch
from comfy.cli_args import args
def get_pin(module):
def get_pin(module, subset="weights"):
return getattr(module, "_pin", None)
def pin_memory(module):
def pin_memory(module, subset="weights", size=None):
pin_state = module._pin_state
if pin_state["failed"] or args.disable_pinned_memory or get_pin(module) is not None:
if pin_state["failed"] or args.disable_pinned_memory or get_pin(module, subset) is not None:
return
hostbuf = pin_state["hostbuf"]
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
hostbuf, stack = pin_state[subset]
if size is None:
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
offset = hostbuf.size
comfy.model_management.ensure_pin_budget(size)
@@ -26,6 +26,6 @@ def pin_memory(module):
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
pin_state["stack"].append((module, offset))
stack.append((module, offset))
comfy.model_management.TOTAL_PINNED_MEMORY += size
return True