implement pinned loras

This commit is contained in:
Rattus 2026-05-08 23:28:43 +10:00
parent eabf34c453
commit 63c27ed9e8
3 changed files with 42 additions and 35 deletions

View File

@ -524,7 +524,7 @@ def ensure_pin_budget(size, evict_active=False):
shortfall += PIN_PRESSURE_HYSTERESIS shortfall += PIN_PRESSURE_HYSTERESIS
for loaded_model in reversed(current_loaded_models): for loaded_model in reversed(current_loaded_models):
model = loaded_model.model model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.dynamic_pins[model.load_device]["active"]): if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
shortfall -= model.partially_unload_ram(shortfall) shortfall -= model.partially_unload_ram(shortfall)
if shortfall <= 0: if shortfall <= 0:
break break
@ -1237,9 +1237,9 @@ def reset_cast_buffers():
for loaded_model in current_loaded_models: for loaded_model in current_loaded_models:
model = loaded_model.model model = loaded_model.model
if model is not None and model.is_dynamic(): if model is not None and model.is_dynamic():
model.dynamic_pins[model.load_device]["active"] = False model.model.dynamic_pins[model.load_device]["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ]) model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0), []) model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), [])
STREAM_CAST_BUFFERS.clear() STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear()

View File

@ -1508,8 +1508,8 @@ class ModelPatcherDynamic(ModelPatcher):
self.model.dynamic_pins = {} self.model.dynamic_pins = {}
if self.load_device not in self.model.dynamic_pins: if self.load_device not in self.model.dynamic_pins:
self.model.dynamic_pins[self.load_device] = { self.model.dynamic_pins[self.load_device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0), []), "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024), []),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0), []), "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), []),
"failed": False, "failed": False,
"active": False, "active": False,
} }
@ -1602,7 +1602,9 @@ class ModelPatcherDynamic(ModelPatcher):
if key in self.patches: if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape: if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0) return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) lowvram_patch = LowVramPatch(key, self.patches)
lowvram_patch._pin_state = pin_state
setattr(m, param_key + "_lowvram_function", lowvram_patch)
num_patches += 1 num_patches += 1
else: else:
setattr(m, param_key + "_lowvram_function", None) setattr(m, param_key + "_lowvram_function", None)

View File

@ -183,34 +183,45 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if xfer_dest is None: if xfer_dest is None:
xfer_dest = get_cast_buffer(dest_size) xfer_dest = get_cast_buffer(dest_size)
if pin is None: def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
if signature is None: if xfer_source is not None:
comfy.pinned_memory.pin_memory(s) if getattr(xfer_source, "is_lowvram_patch", False):
pin = comfy.pinned_memory.get_pin(s) xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
if pin is not None: else:
comfy.model_management.cast_to_gathered(xfer_source, pin) comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
xfer_source = [ pin ]
if pin is None:
pin_offset = get_stream_pin_buffer_offset(dest_size)
if pin_offset is not None:
stream_pin_queue.append((xfer_source, pin_offset, dest_size, xfer_dest))
xfer_source = None
if xfer_source is not None: def handle_pin_miss(m, source, dest, subset="weights", size=None):
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) pin = None
if signature is None:
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
pin = comfy.pinned_memory.get_pin(m, subset=subset)
if pin is not None:
cast_maybe_lowvram_patch(source, pin, None)
return [ pin ]
if pin is None:
pin_offset = get_stream_pin_buffer_offset(size)
if pin_offset is not None:
stream_pin_queue.append((source, pin_offset, size, dest))
return None
return source
if pin is None:
xfer_source = handle_pin_miss(s, xfer_source, xfer_dest, size=dest_size)
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
for param_key in ("weight", "bias"): for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None) lowvram_source = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None: if lowvram_source is not None:
ensure_offload_stream(s, cast_buffer_offset, False) ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_size = lowvram_fn.memory_required() lowvram_size = lowvram_source.memory_required()
lowvram_dest = get_cast_buffer(lowvram_size) lowvram_dest = get_cast_buffer(lowvram_size)
lowvram_fn.prepare(lowvram_dest, None, copy=False, commit=True) lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
pin_offset = get_stream_pin_buffer_offset(lowvram_size)
if pin_offset is not None: pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
stream_pin_queue.append((lowvram_fn, pin_offset, lowvram_size, lowvram_dest)) lowvram_source = handle_pin_miss(lowvram_source, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) if pin is None else [ pin ]
else:
lowvram_fn.prepare(lowvram_dest, offload_stream, copy=True, commit=True) cast_maybe_lowvram_patch(lowvram_source, lowvram_dest, offload_stream)
prefetch["xfer_dest"] = xfer_dest prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest prefetch["cast_dest"] = cast_dest
@ -218,12 +229,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
prefetch["needs_cast"] = needs_cast prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch s._prefetch = prefetch
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
if getattr(xfer_source, "is_lowvram_patch", False):
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
else:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
if stream_pin_offset > 0: if stream_pin_offset > 0:
if stream_pin_hostbuf.size < stream_pin_offset: if stream_pin_hostbuf.size < stream_pin_offset:
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM): if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):