mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 17:22:36 +08:00
Implement JIT pinned memory pressure
Replace the predictive pin pressure mechanism with JIT PIN memory pressure.
This commit is contained in:
parent
5262d9efce
commit
663aaf8b96
@ -498,6 +498,8 @@ current_loaded_models = []
|
||||
|
||||
DIRTY_MMAPS = set()
|
||||
|
||||
PIN_PRESSURE_HYSTERESIS = 128 * 1024 * 1024
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
@ -511,6 +513,21 @@ def mark_mmap_dirty(storage):
|
||||
if mmap_refs is not None:
|
||||
DIRTY_MMAPS.add(mmap_refs[0])
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return
|
||||
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
if shortfall <= 0:
|
||||
return
|
||||
|
||||
shortfall += PIN_PRESSURE_HYSTERESIS
|
||||
for loaded_model in reversed(current_loaded_models):
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and (evict_active or not model.dynamic_pins[model.load_device]["active"]):
|
||||
shortfall -= model.partially_unload_ram(shortfall)
|
||||
if shortfall <= 0:
|
||||
break
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
@ -1133,7 +1150,6 @@ LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_PIN_BUFFERS = {}
|
||||
|
||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||
DEFAULT_PIN_BUFFER_PRIME_SIZE = 1024 ** 2
|
||||
|
||||
def get_cast_buffer(offload_stream, device, size, ref):
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
@ -1177,14 +1193,29 @@ def get_aimdo_cast_buffer(offload_stream, device):
|
||||
def get_pin_buffer(offload_stream):
|
||||
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
|
||||
if pin_buffer is None:
|
||||
# A small non-zero default primes HostBuffer's larger virtual reservation.
|
||||
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(DEFAULT_PIN_BUFFER_PRIME_SIZE)
|
||||
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0)
|
||||
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
|
||||
elif offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
return pin_buffer
|
||||
|
||||
def resize_pin_buffer(pin_buffer, size):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
old_size = getattr(pin_buffer, "_comfy_stream_pin_size", 0)
|
||||
if size <= old_size:
|
||||
return True
|
||||
growth = size - old_size
|
||||
ensure_pin_budget(growth, evict_active=True)
|
||||
try:
|
||||
pin_buffer.extend(size=size, reallocate=True)
|
||||
except RuntimeError:
|
||||
return False
|
||||
pin_buffer._comfy_stream_pin_size = size
|
||||
TOTAL_PINNED_MEMORY += growth
|
||||
return True
|
||||
|
||||
def reset_cast_buffers():
|
||||
global TOTAL_PINNED_MEMORY
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
|
||||
@ -1197,8 +1228,18 @@ def reset_cast_buffers():
|
||||
|
||||
for mmap_obj in DIRTY_MMAPS:
|
||||
mmap_obj.bounce()
|
||||
|
||||
DIRTY_MMAPS.clear()
|
||||
|
||||
for pin_buffer in STREAM_PIN_BUFFERS.values():
|
||||
TOTAL_PINNED_MEMORY -= getattr(pin_buffer, "_comfy_stream_pin_size", 0)
|
||||
if TOTAL_PINNED_MEMORY < 0:
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
|
||||
for loaded_model in current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic():
|
||||
model.dynamic_pins[model.load_device]["active"] = False
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
STREAM_PIN_BUFFERS.clear()
|
||||
@ -1344,8 +1385,7 @@ def pin_memory(tensor):
|
||||
return False
|
||||
|
||||
size = tensor.nbytes
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
ensure_pin_budget(size)
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if ptr == 0:
|
||||
|
||||
@ -1085,7 +1085,7 @@ class ModelPatcher:
|
||||
return 0
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
pass
|
||||
return 0
|
||||
|
||||
def detach(self, unpatch_all=True):
|
||||
self.eject_model()
|
||||
@ -1495,7 +1495,12 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if not hasattr(self.model, "dynamic_pins"):
|
||||
self.model.dynamic_pins = {}
|
||||
if self.load_device not in self.model.dynamic_pins:
|
||||
self.model.dynamic_pins[self.load_device] = {"hostbuf": comfy_aimdo.host_buffer.HostBuffer(0), "stack": [], "failed": False}
|
||||
self.model.dynamic_pins[self.load_device] = {
|
||||
"hostbuf": comfy_aimdo.host_buffer.HostBuffer(0),
|
||||
"stack": [],
|
||||
"failed": False,
|
||||
"active": False,
|
||||
}
|
||||
self.non_dynamic_delegate_model = None
|
||||
assert load_device is not None
|
||||
|
||||
@ -1559,6 +1564,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
vbar = self._vbar_get(create=True)
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
pin_state["failed"] = False
|
||||
pin_state["active"] = True
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
@ -1688,9 +1694,10 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
return self.model.dynamic_pins[self.load_device]["hostbuf"].size
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
freed = 0
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
hostbuf = pin_state["hostbuf"]
|
||||
stack = self.model.dynamic_pins[self.load_device]["stack"]
|
||||
stack = pin_state["stack"]
|
||||
while len(stack) > 0:
|
||||
module, offset = stack.pop()
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
@ -1699,9 +1706,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||
freed += size
|
||||
ram_to_unload -= size
|
||||
if ram_to_unload <= 0:
|
||||
return
|
||||
return freed
|
||||
return freed
|
||||
|
||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||
#This isn't used by the core at all and can only be to load a model out of
|
||||
|
||||
@ -138,6 +138,8 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
|
||||
if stream_pin_hostbuf is None:
|
||||
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
|
||||
if stream_pin_hostbuf is None:
|
||||
return None
|
||||
|
||||
offset = stream_pin_offset
|
||||
stream_pin_offset += buffer_size
|
||||
@ -213,8 +215,10 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
stream_pin_hostbuf_size = getattr(stream_pin_hostbuf, "_comfy_stream_pin_size", stream_pin_hostbuf.size)
|
||||
if stream_pin_hostbuf_size < stream_pin_offset:
|
||||
stream_pin_hostbuf_size = stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM
|
||||
stream_pin_hostbuf.extend(size=stream_pin_hostbuf_size, reallocate=True)
|
||||
stream_pin_hostbuf._comfy_stream_pin_size = stream_pin_hostbuf_size
|
||||
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_hostbuf_size):
|
||||
for xfer_source, _, _, xfer_dest in stream_pin_queue:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
return offload_stream
|
||||
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
|
||||
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
|
||||
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
|
||||
|
||||
@ -16,9 +16,7 @@ def pin_memory(module):
|
||||
hostbuf = pin_state["hostbuf"]
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
offset = hostbuf.size
|
||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||
pin_state["failed"] = True
|
||||
return False
|
||||
comfy.model_management.ensure_pin_budget(size)
|
||||
|
||||
try:
|
||||
hostbuf.extend(size=size)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user