mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
implement pinned loras
This commit is contained in:
parent
eabf34c453
commit
63c27ed9e8
@ -524,7 +524,7 @@ def ensure_pin_budget(size, evict_active=False):
|
||||
shortfall += PIN_PRESSURE_HYSTERESIS
|
||||
for loaded_model in reversed(current_loaded_models):
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and (evict_active or not model.dynamic_pins[model.load_device]["active"]):
|
||||
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
|
||||
shortfall -= model.partially_unload_ram(shortfall)
|
||||
if shortfall <= 0:
|
||||
break
|
||||
@ -1237,9 +1237,9 @@ def reset_cast_buffers():
|
||||
for loaded_model in current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic():
|
||||
model.dynamic_pins[model.load_device]["active"] = False
|
||||
model.model.dynamic_pins[model.load_device]["active"] = False
|
||||
model.partially_unload_ram(1e30, subsets=[ "patches" ])
|
||||
model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0), [])
|
||||
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), [])
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
|
||||
@ -1508,8 +1508,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
self.model.dynamic_pins = {}
|
||||
if self.load_device not in self.model.dynamic_pins:
|
||||
self.model.dynamic_pins[self.load_device] = {
|
||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0), []),
|
||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0), []),
|
||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024), []),
|
||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), []),
|
||||
"failed": False,
|
||||
"active": False,
|
||||
}
|
||||
@ -1602,7 +1602,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if key in self.patches:
|
||||
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
|
||||
return (True, 0)
|
||||
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
||||
lowvram_patch = LowVramPatch(key, self.patches)
|
||||
lowvram_patch._pin_state = pin_state
|
||||
setattr(m, param_key + "_lowvram_function", lowvram_patch)
|
||||
num_patches += 1
|
||||
else:
|
||||
setattr(m, param_key + "_lowvram_function", None)
|
||||
|
||||
63
comfy/ops.py
63
comfy/ops.py
@ -183,34 +183,45 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if xfer_dest is None:
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
if pin is None:
|
||||
if signature is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
pin = comfy.pinned_memory.get_pin(s)
|
||||
if pin is not None:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, pin)
|
||||
xfer_source = [ pin ]
|
||||
if pin is None:
|
||||
pin_offset = get_stream_pin_buffer_offset(dest_size)
|
||||
if pin_offset is not None:
|
||||
stream_pin_queue.append((xfer_source, pin_offset, dest_size, xfer_dest))
|
||||
xfer_source = None
|
||||
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
|
||||
if xfer_source is not None:
|
||||
if getattr(xfer_source, "is_lowvram_patch", False):
|
||||
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
||||
else:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
|
||||
|
||||
if xfer_source is not None:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
def handle_pin_miss(m, source, dest, subset="weights", size=None):
|
||||
pin = None
|
||||
if signature is None:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch(source, pin, None)
|
||||
return [ pin ]
|
||||
if pin is None:
|
||||
pin_offset = get_stream_pin_buffer_offset(size)
|
||||
if pin_offset is not None:
|
||||
stream_pin_queue.append((source, pin_offset, size, dest))
|
||||
return None
|
||||
return source
|
||||
|
||||
if pin is None:
|
||||
xfer_source = handle_pin_miss(s, xfer_source, xfer_dest, size=dest_size)
|
||||
|
||||
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
|
||||
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_source = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_source is not None:
|
||||
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||
lowvram_size = lowvram_fn.memory_required()
|
||||
lowvram_size = lowvram_source.memory_required()
|
||||
lowvram_dest = get_cast_buffer(lowvram_size)
|
||||
lowvram_fn.prepare(lowvram_dest, None, copy=False, commit=True)
|
||||
pin_offset = get_stream_pin_buffer_offset(lowvram_size)
|
||||
if pin_offset is not None:
|
||||
stream_pin_queue.append((lowvram_fn, pin_offset, lowvram_size, lowvram_dest))
|
||||
else:
|
||||
lowvram_fn.prepare(lowvram_dest, offload_stream, copy=True, commit=True)
|
||||
lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
|
||||
|
||||
pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
|
||||
lowvram_source = handle_pin_miss(lowvram_source, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) if pin is None else [ pin ]
|
||||
|
||||
cast_maybe_lowvram_patch(lowvram_source, lowvram_dest, offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
@ -218,12 +229,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
prefetch["needs_cast"] = needs_cast
|
||||
s._prefetch = prefetch
|
||||
|
||||
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
|
||||
if getattr(xfer_source, "is_lowvram_patch", False):
|
||||
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
||||
else:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
|
||||
|
||||
if stream_pin_offset > 0:
|
||||
if stream_pin_hostbuf.size < stream_pin_offset:
|
||||
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user