Implement JIT pinned memory pressure
Replace the predictive pin pressure mechanism with JIT pin memory pressure: rather than priming pin buffers and refusing requests that would exceed the budget up front, the pinned-memory budget is now enforced at pin time, by partially unloading pinned RAM from other loaded models.
parent 5262d9efce
commit 663aaf8b96
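In short: the previous mechanism primed pinned host buffers ahead of time and hard-failed any pin request that would exceed MAX_PINNED_MEMORY; with this change, pressure is applied just in time, at the moment a request would go over budget. A minimal self-contained sketch of the idea, with stand-in types (the real logic is ensure_pin_budget in the diff below; in the real code the global accounting is updated inside partially_unload_ram rather than by the caller):

    # JIT pin pressure in miniature; StubModel and the numbers are illustrative,
    # not ComfyUI API.
    MAX_PINNED_MEMORY = 4 * 1024**3
    PIN_PRESSURE_HYSTERESIS = 128 * 1024**2
    TOTAL_PINNED_MEMORY = 4 * 1024**3              # already at the cap

    class StubModel:
        def __init__(self, pinned):
            self.pinned = pinned                   # bytes pinned by this model

        def partially_unload_ram(self, n):         # unpin up to n bytes
            freed = min(self.pinned, n)
            self.pinned -= freed
            return freed

    loaded = [StubModel(1024**3), StubModel(2 * 1024**3)]  # most recent first

    def ensure_pin_budget(size):
        global TOTAL_PINNED_MEMORY
        shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
        if shortfall <= 0:
            return                                 # already within budget
        shortfall += PIN_PRESSURE_HYSTERESIS       # free extra to avoid thrash
        for model in reversed(loaded):             # coldest pins evicted first
            freed = model.partially_unload_ram(shortfall)
            shortfall -= freed
            TOTAL_PINNED_MEMORY -= freed
            if shortfall <= 0:
                break

    ensure_pin_budget(256 * 1024**2)               # make room for a 256 MiB pin
    assert MAX_PINNED_MEMORY - TOTAL_PINNED_MEMORY >= 256 * 1024**2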
@@ -498,6 +498,8 @@ current_loaded_models = []
 
 DIRTY_MMAPS = set()
 
+PIN_PRESSURE_HYSTERESIS = 128 * 1024 * 1024
+
 def module_size(module):
     module_mem = 0
     sd = module.state_dict()
@@ -511,6 +513,21 @@ def mark_mmap_dirty(storage):
     if mmap_refs is not None:
         DIRTY_MMAPS.add(mmap_refs[0])
 
+def ensure_pin_budget(size, evict_active=False):
+    if MAX_PINNED_MEMORY <= 0:
+        return
+
+    shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
+    if shortfall <= 0:
+        return
+
+    shortfall += PIN_PRESSURE_HYSTERESIS
+    for loaded_model in reversed(current_loaded_models):
+        model = loaded_model.model
+        if model is not None and model.is_dynamic() and (evict_active or not model.dynamic_pins[model.load_device]["active"]):
+            shortfall -= model.partially_unload_ram(shortfall)
+            if shortfall <= 0:
+                break
+
 class LoadedModel:
     def __init__(self, model):
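A worked example of the shortfall arithmetic above, with illustrative numbers:

    MAX_PINNED_MEMORY = 8 * 1024**3           # 8 GiB cap (8192 MiB)
    TOTAL_PINNED_MEMORY = 7900 * 1024**2      # 7900 MiB already pinned
    size = 300 * 1024**2                      # incoming 300 MiB pin request

    shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
    assert shortfall == 8 * 1024**2           # only 8 MiB over budget...
    shortfall += 128 * 1024 * 1024            # ...but evict 136 MiB in total,
    assert shortfall == 136 * 1024**2         # so the very next pin does not
                                              # immediately re-trigger eviction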
@@ -1133,7 +1150,6 @@ LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
 STREAM_PIN_BUFFERS = {}
 
 DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
-DEFAULT_PIN_BUFFER_PRIME_SIZE = 1024 ** 2
 
 def get_cast_buffer(offload_stream, device, size, ref):
     global LARGEST_CASTED_WEIGHT
@@ -1177,14 +1193,29 @@ def get_aimdo_cast_buffer(offload_stream, device):
 def get_pin_buffer(offload_stream):
     pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
     if pin_buffer is None:
-        # A small non-zero default primes HostBuffer's larger virtual reservation.
-        pin_buffer = comfy_aimdo.host_buffer.HostBuffer(DEFAULT_PIN_BUFFER_PRIME_SIZE)
+        pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0)
         STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
     elif offload_stream is not None:
         offload_stream.synchronize()
     return pin_buffer
 
+def resize_pin_buffer(pin_buffer, size):
+    global TOTAL_PINNED_MEMORY
+    old_size = getattr(pin_buffer, "_comfy_stream_pin_size", 0)
+    if size <= old_size:
+        return True
+    growth = size - old_size
+    ensure_pin_budget(growth, evict_active=True)
+    try:
+        pin_buffer.extend(size=size, reallocate=True)
+    except RuntimeError:
+        return False
+    pin_buffer._comfy_stream_pin_size = size
+    TOTAL_PINNED_MEMORY += growth
+    return True
+
 def reset_cast_buffers():
+    global TOTAL_PINNED_MEMORY
     global LARGEST_CASTED_WEIGHT
     global LARGEST_AIMDO_CASTED_WEIGHT
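resize_pin_buffer stamps the accounted size onto the buffer as _comfy_stream_pin_size, so only growth is charged against the global budget and a failed reallocation leaves the books untouched. A runnable sketch of that bookkeeping pattern, with a stand-in buffer class (the real HostBuffer comes from comfy_aimdo; the ensure_pin_budget call is elided here):

    TOTAL_PINNED_MEMORY = 0

    class StubHostBuffer:
        def __init__(self, size=0):
            self.size = size

        def extend(self, size, reallocate=False):   # the real one may raise
            self.size = max(self.size, size)        # RuntimeError on pin failure

    def resize_pin_buffer(buf, size):
        global TOTAL_PINNED_MEMORY
        old_size = getattr(buf, "_comfy_stream_pin_size", 0)
        if size <= old_size:
            return True                             # grow-only: shrink is a no-op
        try:
            buf.extend(size=size, reallocate=True)
        except RuntimeError:
            return False                            # books untouched on failure
        buf._comfy_stream_pin_size = size           # stamp the accounted size
        TOTAL_PINNED_MEMORY += size - old_size      # charge only the growth
        return True

    buf = StubHostBuffer()
    assert resize_pin_buffer(buf, 1024) and TOTAL_PINNED_MEMORY == 1024
    assert resize_pin_buffer(buf, 512) and TOTAL_PINNED_MEMORY == 1024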
@@ -1197,8 +1228,18 @@ def reset_cast_buffers():
 
     for mmap_obj in DIRTY_MMAPS:
         mmap_obj.bounce()
 
     DIRTY_MMAPS.clear()
 
+    for pin_buffer in STREAM_PIN_BUFFERS.values():
+        TOTAL_PINNED_MEMORY -= getattr(pin_buffer, "_comfy_stream_pin_size", 0)
+    if TOTAL_PINNED_MEMORY < 0:
+        TOTAL_PINNED_MEMORY = 0
+
+    for loaded_model in current_loaded_models:
+        model = loaded_model.model
+        if model is not None and model.is_dynamic():
+            model.dynamic_pins[model.load_device]["active"] = False
+
     STREAM_CAST_BUFFERS.clear()
     STREAM_AIMDO_CAST_BUFFERS.clear()
     STREAM_PIN_BUFFERS.clear()
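Teardown mirrors the allocation-side accounting: each stream buffer's stamped size is released from TOTAL_PINNED_MEMORY, and the counter is clamped at zero so accounting drift can never drive it negative. A tiny sketch of that clamp (stand-in objects, illustrative numbers):

    class Buf:
        pass

    TOTAL_PINNED_MEMORY = 2 * 1024**2            # counter has drifted low
    a, b = Buf(), Buf()
    a._comfy_stream_pin_size = 1024**2
    b._comfy_stream_pin_size = 2 * 1024**2       # more stamped than counted

    for pin_buffer in (a, b):
        TOTAL_PINNED_MEMORY -= getattr(pin_buffer, "_comfy_stream_pin_size", 0)
    if TOTAL_PINNED_MEMORY < 0:
        TOTAL_PINNED_MEMORY = 0                  # clamp instead of going negative
    assert TOTAL_PINNED_MEMORY == 0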
@@ -1344,8 +1385,7 @@ def pin_memory(tensor):
         return False
 
     size = tensor.nbytes
-    if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
-        return False
+    ensure_pin_budget(size)
 
     ptr = tensor.data_ptr()
     if ptr == 0:
@@ -1085,7 +1085,7 @@ class ModelPatcher:
         return 0
 
     def partially_unload_ram(self, ram_to_unload):
-        pass
+        return 0
 
     def detach(self, unpatch_all=True):
         self.eject_model()
@@ -1495,7 +1495,12 @@ class ModelPatcherDynamic(ModelPatcher):
         if not hasattr(self.model, "dynamic_pins"):
             self.model.dynamic_pins = {}
         if self.load_device not in self.model.dynamic_pins:
-            self.model.dynamic_pins[self.load_device] = {"hostbuf": comfy_aimdo.host_buffer.HostBuffer(0), "stack": [], "failed": False}
+            self.model.dynamic_pins[self.load_device] = {
+                "hostbuf": comfy_aimdo.host_buffer.HostBuffer(0),
+                "stack": [],
+                "failed": False,
+                "active": False,
+            }
         self.non_dynamic_delegate_model = None
         assert load_device is not None
@@ -1559,6 +1564,7 @@ class ModelPatcherDynamic(ModelPatcher):
         vbar = self._vbar_get(create=True)
         pin_state = self.model.dynamic_pins[self.load_device]
         pin_state["failed"] = False
+        pin_state["active"] = True
         if vbar is not None:
             vbar.prioritize()
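Taken together with the reset in reset_cast_buffers and the predicate in ensure_pin_budget, this gives each dynamic model's pins a simple lifecycle: marked active while the model is being loaded and used, skipped by JIT eviction while active (unless evict_active=True forces the issue), and marked inactive again on reset. A stand-in sketch of that lifecycle:

    pin_state = {"hostbuf": None, "stack": [], "failed": False, "active": False}

    def start_load():                       # mirrors the hunk above
        pin_state["failed"] = False
        pin_state["active"] = True          # protect these pins while in use

    def evictable(evict_active=False):      # mirrors ensure_pin_budget's test
        return evict_active or not pin_state["active"]

    def reset():                            # mirrors reset_cast_buffers
        pin_state["active"] = False

    start_load()
    assert not evictable()                  # protected during the run
    assert evictable(evict_active=True)     # ...unless eviction is forced
    reset()
    assert evictable()                      # fair game for the next pressure pass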
@@ -1688,9 +1694,10 @@ class ModelPatcherDynamic(ModelPatcher):
         return self.model.dynamic_pins[self.load_device]["hostbuf"].size
 
     def partially_unload_ram(self, ram_to_unload):
+        freed = 0
         pin_state = self.model.dynamic_pins[self.load_device]
         hostbuf = pin_state["hostbuf"]
-        stack = self.model.dynamic_pins[self.load_device]["stack"]
+        stack = pin_state["stack"]
         while len(stack) > 0:
             module, offset = stack.pop()
             size = module._pin.numel() * module._pin.element_size()
@@ -1699,9 +1706,11 @@ class ModelPatcherDynamic(ModelPatcher):
             comfy.model_management.TOTAL_PINNED_MEMORY -= size
             if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
                 comfy.model_management.TOTAL_PINNED_MEMORY = 0
+            freed += size
             ram_to_unload -= size
             if ram_to_unload <= 0:
-                return
+                return freed
+        return freed
 
     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
         #This isn't used by the core at all and can only be to load a model out of
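Eviction pops pinned modules off the per-device stack in LIFO order and now reports the number of bytes actually freed, so ensure_pin_budget can subtract it from the shortfall and stop early. A self-contained sketch of that loop (StubModule stands in for the real modules, whose pinned size comes from module._pin):

    class StubModule:
        def __init__(self, nbytes):
            self.nbytes = nbytes        # stands in for _pin.numel() * element_size()

    def partially_unload_ram(stack, ram_to_unload):
        freed = 0
        while len(stack) > 0:
            module, offset = stack.pop()    # most recently pinned goes first
            freed += module.nbytes
            ram_to_unload -= module.nbytes
            if ram_to_unload <= 0:
                return freed                # caller subtracts this from shortfall
        return freed                        # stack exhausted: freed < requested

    stack = [(StubModule(100), 0), (StubModule(50), 100), (StubModule(25), 150)]
    assert partially_unload_ram(stack, 60) == 75    # pops the 25 then the 50
    assert len(stack) == 1                          # the 100 stays pinned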
@@ -138,6 +138,8 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
 
     if stream_pin_hostbuf is None:
         stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
+        if stream_pin_hostbuf is None:
+            return None
 
     offset = stream_pin_offset
     stream_pin_offset += buffer_size
@@ -213,8 +215,10 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
     stream_pin_hostbuf_size = getattr(stream_pin_hostbuf, "_comfy_stream_pin_size", stream_pin_hostbuf.size)
     if stream_pin_hostbuf_size < stream_pin_offset:
         stream_pin_hostbuf_size = stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM
-        stream_pin_hostbuf.extend(size=stream_pin_hostbuf_size, reallocate=True)
-        stream_pin_hostbuf._comfy_stream_pin_size = stream_pin_hostbuf_size
+        if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_hostbuf_size):
+            for xfer_source, _, _, xfer_dest in stream_pin_queue:
+                comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+            return offload_stream
     stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
     stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
     for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
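When the staging buffer cannot grow (resize_pin_buffer returns False), the queued weights are cast straight to the device instead of bouncing through pinned memory: slower, but the load keeps going instead of raising. A pure-Python sketch of that degrade-gracefully shape (list copies stand in for the real tensor transfers):

    def cast_to_gathered(src, dst):         # stand-in for the real transfer
        dst[:] = src                        # direct copy, no pinned staging

    def flush_pin_queue(queue, staging_grew):
        if not staging_grew:                # pinned path unavailable
            for src, _offset, _size, dst in queue:
                cast_to_gathered(src, dst)  # fall back to direct transfers
            return "direct"
        return "staged"                     # normal pinned, stream-ordered path

    queue = [([1, 2], 0, 2, [0, 0]), ([3], 2, 1, [0])]
    assert flush_pin_queue(queue, staging_grew=False) == "direct"
    assert queue[0][3] == [1, 2] and queue[1][3] == [3]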
@@ -16,9 +16,7 @@ def pin_memory(module):
     hostbuf = pin_state["hostbuf"]
     size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
     offset = hostbuf.size
-    if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
-        pin_state["failed"] = True
-        return False
+    comfy.model_management.ensure_pin_budget(size)
 
     try:
         hostbuf.extend(size=size)