diff --git a/comfy/model_management.py b/comfy/model_management.py
index 40f72fa1b..ca4318a45 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1238,6 +1238,8 @@ def reset_cast_buffers():
         model = loaded_model.model
         if model is not None and model.is_dynamic():
             model.dynamic_pins[model.load_device]["active"] = False
+            model.partially_unload_ram(1e30, subsets=[ "patches" ])
+            model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0), [])
 
     STREAM_CAST_BUFFERS.clear()
     STREAM_AIMDO_CAST_BUFFERS.clear()
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 41375c80e..d18f76456 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1508,8 +1508,8 @@ class ModelPatcherDynamic(ModelPatcher):
             self.model.dynamic_pins = {}
         if self.load_device not in self.model.dynamic_pins:
             self.model.dynamic_pins[self.load_device] = {
-                "hostbuf": comfy_aimdo.host_buffer.HostBuffer(0),
-                "stack": [],
+                "weights": (comfy_aimdo.host_buffer.HostBuffer(0), []),
+                "patches": (comfy_aimdo.host_buffer.HostBuffer(0), []),
                 "failed": False,
                 "active": False,
             }
@@ -1703,25 +1703,26 @@
             return freed
 
     def pinned_memory_size(self):
-        return self.model.dynamic_pins[self.load_device]["hostbuf"].size
+        return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
+                self.model.dynamic_pins[self.load_device]["patches"][0].size)
 
-    def partially_unload_ram(self, ram_to_unload):
+    def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
         freed = 0
         pin_state = self.model.dynamic_pins[self.load_device]
-        hostbuf = pin_state["hostbuf"]
-        stack = pin_state["stack"]
-        while len(stack) > 0:
-            module, offset = stack.pop()
-            size = module._pin.numel() * module._pin.element_size()
-            del module._pin
-            hostbuf.truncate(offset)
-            comfy.model_management.TOTAL_PINNED_MEMORY -= size
-            if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
-                comfy.model_management.TOTAL_PINNED_MEMORY = 0
-            freed += size
-            ram_to_unload -= size
-            if ram_to_unload <= 0:
-                return freed
+        for subset in subsets:
+            hostbuf, stack = pin_state[subset]
+            while len(stack) > 0:
+                module, offset = stack.pop()
+                size = module._pin.numel() * module._pin.element_size()
+                del module._pin
+                hostbuf.truncate(offset)
+                comfy.model_management.TOTAL_PINNED_MEMORY -= size
+                if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
+                    comfy.model_management.TOTAL_PINNED_MEMORY = 0
+                freed += size
+                ram_to_unload -= size
+                if ram_to_unload <= 0:
+                    return freed
         return freed
 
     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py
index a35759aad..208c777f8 100644
--- a/comfy/pinned_memory.py
+++ b/comfy/pinned_memory.py
@@ -5,16 +5,17 @@
 import comfy_aimdo.torch
 
 from comfy.cli_args import args
 
-def get_pin(module):
+def get_pin(module, subset="weights"):
     return getattr(module, "_pin", None)
 
-def pin_memory(module):
+def pin_memory(module, subset="weights", size=None):
     pin_state = module._pin_state
-    if pin_state["failed"] or args.disable_pinned_memory or get_pin(module) is not None:
+    if pin_state["failed"] or args.disable_pinned_memory or get_pin(module, subset) is not None:
         return
-    hostbuf = pin_state["hostbuf"]
-    size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
+    hostbuf, stack = pin_state[subset]
+    if size is None:
+        size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
     offset = hostbuf.size
 
     comfy.model_management.ensure_pin_budget(size)
@@ -26,6 +27,6 @@ def pin_memory(module):
     module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
     module._pin.untyped_storage()._comfy_hostbuf = hostbuf
 
-    pin_state["stack"].append((module, offset))
+    stack.append((module, offset))
     comfy.model_management.TOTAL_PINNED_MEMORY += size
     return True