Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-05-11 01:32:31 +08:00)
implement pinned loras

commit 63c27ed9e8
parent eabf34c453
@@ -524,7 +524,7 @@ def ensure_pin_budget(size, evict_active=False):
         shortfall += PIN_PRESSURE_HYSTERESIS
     for loaded_model in reversed(current_loaded_models):
         model = loaded_model.model
-        if model is not None and model.is_dynamic() and (evict_active or not model.dynamic_pins[model.load_device]["active"]):
+        if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
             shortfall -= model.partially_unload_ram(shortfall)
             if shortfall <= 0:
                 break
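Note: the hunk above only changes where the pin state lives (on the inner model rather than the patcher); the eviction logic itself is untouched. For reference, the pattern is: pad the shortfall with a hysteresis margin so freshly released pins are not immediately re-acquired, then walk the loaded models in reverse order, releasing pinned RAM until the budget is met. A minimal self-contained sketch (only the names mirrored from the diff are real; the rest is illustrative, not ComfyUI API):

    # Sketch of the pin-budget eviction loop; `models` stands in for
    # current_loaded_models and each object exposes partially_unload_ram().
    def evict_until_met(shortfall, models, hysteresis):
        shortfall += hysteresis  # over-free a little to avoid pin churn
        for model in reversed(models):
            shortfall -= model.partially_unload_ram(shortfall)
            if shortfall <= 0:
                break
        return shortfall  # <= 0 means the budget was satisfied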
@@ -1237,9 +1237,9 @@ def reset_cast_buffers():
     for loaded_model in current_loaded_models:
         model = loaded_model.model
         if model is not None and model.is_dynamic():
-            model.dynamic_pins[model.load_device]["active"] = False
+            model.model.dynamic_pins[model.load_device]["active"] = False
             model.partially_unload_ram(1e30, subsets=[ "patches" ])
-            model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0), [])
+            model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), [])
 
     STREAM_CAST_BUFFERS.clear()
     STREAM_AIMDO_CAST_BUFFERS.clear()
@@ -1508,8 +1508,8 @@ class ModelPatcherDynamic(ModelPatcher):
             self.model.dynamic_pins = {}
         if self.load_device not in self.model.dynamic_pins:
             self.model.dynamic_pins[self.load_device] = {
-                "weights": (comfy_aimdo.host_buffer.HostBuffer(0), []),
-                "patches": (comfy_aimdo.host_buffer.HostBuffer(0), []),
+                "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024), []),
+                "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), []),
                 "failed": False,
                 "active": False,
             }
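Note: the new second HostBuffer argument is the substance of this hunk; judging by the call sites it appears to reserve pinned capacity up front (64 MiB for base weights, 8 MiB for LoRA patches). A stand-in sketch of the per-device pin state, assuming HostBuffer(size, capacity) semantics (comfy_aimdo's actual constructor may differ):

    # Hypothetical stand-in for comfy_aimdo.host_buffer.HostBuffer, which
    # manages page-locked (pinned) host memory for async GPU copies.
    class HostBuffer:
        def __init__(self, size, capacity=0):
            self.size = size                      # bytes currently in use
            self.capacity = max(size, capacity)   # bytes reserved up front

    def make_pin_state():
        # Per-device bookkeeping mirroring the diff: one pinned arena for
        # base weights, a smaller one for LoRA patches, plus status flags.
        return {
            "weights": (HostBuffer(0, 64 * 1024 * 1024), []),
            "patches": (HostBuffer(0, 8 * 1024 * 1024), []),
            "failed": False,   # pinning gave up on this device
            "active": False,   # pins are in use by the current run
        }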
@@ -1602,7 +1602,9 @@ class ModelPatcherDynamic(ModelPatcher):
                 if key in self.patches:
                     if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
                         return (True, 0)
-                    setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
+                    lowvram_patch = LowVramPatch(key, self.patches)
+                    lowvram_patch._pin_state = pin_state
+                    setattr(m, param_key + "_lowvram_function", lowvram_patch)
                     num_patches += 1
                 else:
                     setattr(m, param_key + "_lowvram_function", None)
comfy/ops.py (63 lines changed)
@@ -183,34 +183,45 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking
         if xfer_dest is None:
             xfer_dest = get_cast_buffer(dest_size)
 
-        if pin is None:
-            if signature is None:
-                comfy.pinned_memory.pin_memory(s)
-            pin = comfy.pinned_memory.get_pin(s)
-            if pin is not None:
-                comfy.model_management.cast_to_gathered(xfer_source, pin)
-                xfer_source = [ pin ]
-        if pin is None:
-            pin_offset = get_stream_pin_buffer_offset(dest_size)
-            if pin_offset is not None:
-                stream_pin_queue.append((xfer_source, pin_offset, dest_size, xfer_dest))
-                xfer_source = None
-
-        if xfer_source is not None:
-            comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+        def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
+            if xfer_source is not None:
+                if getattr(xfer_source, "is_lowvram_patch", False):
+                    xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
+                else:
+                    comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
+
+        def handle_pin_miss(m, source, dest, subset="weights", size=None):
+            pin = None
+            if signature is None:
+                comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
+                pin = comfy.pinned_memory.get_pin(m, subset=subset)
+                if pin is not None:
+                    cast_maybe_lowvram_patch(source, pin, None)
+                    return [ pin ]
+            if pin is None:
+                pin_offset = get_stream_pin_buffer_offset(size)
+                if pin_offset is not None:
+                    stream_pin_queue.append((source, pin_offset, size, dest))
+                    return None
+            return source
+
+        if pin is None:
+            xfer_source = handle_pin_miss(s, xfer_source, xfer_dest, size=dest_size)
+
+        cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
 
         for param_key in ("weight", "bias"):
-            lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
-            if lowvram_fn is not None:
+            lowvram_source = getattr(s, param_key + "_lowvram_function", None)
+            if lowvram_source is not None:
                 ensure_offload_stream(s, cast_buffer_offset, False)
-                lowvram_size = lowvram_fn.memory_required()
+                lowvram_size = lowvram_source.memory_required()
                 lowvram_dest = get_cast_buffer(lowvram_size)
-                lowvram_fn.prepare(lowvram_dest, None, copy=False, commit=True)
-                pin_offset = get_stream_pin_buffer_offset(lowvram_size)
-                if pin_offset is not None:
-                    stream_pin_queue.append((lowvram_fn, pin_offset, lowvram_size, lowvram_dest))
-                else:
-                    lowvram_fn.prepare(lowvram_dest, offload_stream, copy=True, commit=True)
+                lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
+                pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
+                lowvram_source = handle_pin_miss(lowvram_source, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) if pin is None else [ pin ]
+                cast_maybe_lowvram_patch(lowvram_source, lowvram_dest, offload_stream)
 
         prefetch["xfer_dest"] = xfer_dest
         prefetch["cast_dest"] = cast_dest
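Note on the refactor above: the previously inline pin logic is factored into handle_pin_miss so the LoRA-patch path can reuse it, and cast_maybe_lowvram_patch dispatches on is_lowvram_patch so a destination buffer can receive either a plain gathered cast or a patch's prepare(). The decision has three outcomes: a dedicated pin, a queued copy through the shared stream-pin buffer, or the raw unpinned source. A sketch of just that three-way decision (try_pin and reserve_stream_slot are hypothetical stand-ins, not ComfyUI functions):

    # Models pin availability with a simple byte budget so the sketch runs.
    def try_pin(size, budget):
        if size <= budget["dedicated"]:
            budget["dedicated"] -= size
            return bytearray(size)            # stands in for a pinned buffer
        return None

    def reserve_stream_slot(size, budget):
        if size <= budget["stream"]:
            budget["stream"] -= size
            return budget["stream"]           # offset of the reserved slot
        return None

    def resolve_source(source, dest, size, budget, queue):
        pin = try_pin(size, budget)
        if pin is not None:                   # 1) pinned once, reused later
            return [pin]
        slot = reserve_stream_slot(size, budget)
        if slot is not None:                  # 2) copy deferred to the queue
            queue.append((source, slot, size, dest))
            return None
        return source                         # 3) plain unpinned fallback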
@@ -218,12 +229,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking
         prefetch["needs_cast"] = needs_cast
         s._prefetch = prefetch
 
-    def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
-        if getattr(xfer_source, "is_lowvram_patch", False):
-            xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
-        else:
-            comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
-
     if stream_pin_offset > 0:
         if stream_pin_hostbuf.size < stream_pin_offset:
             if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
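Note: the surviving context at the end of this hunk shows the overflow handling the commit leaves in place: when the queued stream-pin copies outgrow the pinned host buffer, it is resized once with STREAM_PIN_BUFFER_HEADROOM of slack, so the next small overflow does not force another reallocation (pinned allocations are expensive and can fail, hence the checked boolean result). A rough sketch of that grow-with-headroom policy, using a plain bytearray in place of pinned memory:

    # Grow-with-headroom sketch; real pinned reallocation can fail, which
    # is why the caller checks a boolean result.
    def ensure_capacity(buf, needed, headroom):
        if len(buf) >= needed:
            return buf, True
        try:
            return bytearray(needed + headroom), True
        except MemoryError:
            return buf, False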