From 01eba77dba09a1627b57fe24a5dde03f10bdf778 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Fri, 8 May 2026 08:19:47 +1000
Subject: [PATCH] lora: re-implement as inplace swiss-army-knife operation

---
 comfy/lora.py             | 10 +++++-----
 comfy/model_management.py |  7 +++----
 comfy/model_patcher.py    | 13 +++++++++----
 comfy/ops.py              | 25 ++++++++++++++++++-------
 4 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/comfy/lora.py b/comfy/lora.py
index f7c7c21a5..2b8699710 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -475,7 +475,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
 
     return weight
 
-def prefetch_prepared_value(value, counter, destination, stream):
+def prefetch_prepared_value(value, counter, destination, stream, copy):
     if isinstance(value, torch.Tensor):
         size = comfy.memory_management.vram_aligned_size(value)
         offset = counter[0]
@@ -484,14 +484,14 @@ def prefetch_prepared_value(value, counter, destination, stream):
             return value
 
         dest = destination[offset:offset + size]
-        if stream is not None:
+        if copy:
             comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
         return comfy.memory_management.interpret_gathered_like([value], dest)[0]
     elif isinstance(value, weight_adapter.WeightAdapterBase):
-        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream))
+        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
     elif isinstance(value, tuple):
-        return tuple(prefetch_prepared_value(item, counter, destination, stream) for item in value)
+        return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
     elif isinstance(value, list):
-        return [prefetch_prepared_value(item, counter, destination, stream) for item in value]
+        return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
     return value
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 6a2126cb5..40f72fa1b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1201,7 +1201,7 @@ def get_pin_buffer(offload_stream):
 
 def resize_pin_buffer(pin_buffer, size):
     global TOTAL_PINNED_MEMORY
-    old_size = getattr(pin_buffer, "_comfy_stream_pin_size", 0)
+    old_size = pin_buffer.size
     if size <= old_size:
         return True
     growth = size - old_size
@@ -1210,8 +1210,7 @@
         pin_buffer.extend(size=size, reallocate=True)
     except RuntimeError:
         return False
-    pin_buffer._comfy_stream_pin_size = size
-    TOTAL_PINNED_MEMORY += growth
+    TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
     return True
 
 def reset_cast_buffers():
@@ -1231,7 +1230,7 @@
     DIRTY_MMAPS.clear()
 
     for pin_buffer in STREAM_PIN_BUFFERS.values():
-        TOTAL_PINNED_MEMORY -= getattr(pin_buffer, "_comfy_stream_pin_size", 0)
+        TOTAL_PINNED_MEMORY -= pin_buffer.size
 
     if TOTAL_PINNED_MEMORY < 0:
         TOTAL_PINNED_MEMORY = 0
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 995b1b36b..41375c80e 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -117,6 +117,8 @@ def string_to_seed(data):
     return comfy.utils.string_to_seed(data)
 
 class LowVramPatch:
+    is_lowvram_patch = True
+
     def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
@@ -127,15 +129,18 @@ class LowVramPatch:
     def memory_required(self):
         counter = [0]
         for patch in self.patches[self.key]:
-            comfy.lora.prefetch_prepared_value(patch[1], counter, None, None)
+            comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
        return counter[0]
 
-    def prepare(self, destination, stream):
+    def prepare(self, destination, stream, copy=True, commit=True):
         counter = [0]
-        self.prepared_patches = [
-            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream), patch[2], patch[3], patch[4])
+        prepared_patches = [
+            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
             for patch in self.patches[self.key]
         ]
+        if commit:
+            self.prepared_patches = prepared_patches
+        return prepared_patches
 
     def clear_prepared(self):
         self.prepared_patches = None
diff --git a/comfy/ops.py b/comfy/ops.py
index cb0308acd..6b873cc66 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -203,7 +203,14 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
             lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
             if lowvram_fn is not None:
                 ensure_offload_stream(s, cast_buffer_offset, False)
-                lowvram_fn.prepare(get_cast_buffer(lowvram_fn.memory_required()), offload_stream)
+                lowvram_size = lowvram_fn.memory_required()
+                lowvram_dest = get_cast_buffer(lowvram_size)
+                lowvram_fn.prepare(lowvram_dest, None, copy=False, commit=True)
+                pin_offset = get_stream_pin_buffer_offset(lowvram_size)
+                if pin_offset is not None:
+                    stream_pin_queue.append((lowvram_fn, pin_offset, lowvram_size, lowvram_dest))
+                else:
+                    lowvram_fn.prepare(lowvram_dest, offload_stream, copy=True, commit=True)
 
         prefetch["xfer_dest"] = xfer_dest
         prefetch["cast_dest"] = cast_dest
@@ -211,19 +218,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
         prefetch["needs_cast"] = needs_cast
         s._prefetch = prefetch
 
+    def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
+        if getattr(xfer_source, "is_lowvram_patch", False):
+            xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
+        else:
+            comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
+
     if stream_pin_offset > 0:
-        stream_pin_hostbuf_size = getattr(stream_pin_hostbuf, "_comfy_stream_pin_size", stream_pin_hostbuf.size)
-        if stream_pin_hostbuf_size < stream_pin_offset:
-            stream_pin_hostbuf_size = stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM
-            if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_hostbuf_size):
+        if stream_pin_hostbuf.size < stream_pin_offset:
+            if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
                for xfer_source, _, _, xfer_dest in stream_pin_queue:
-                    comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+                    cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
                return offload_stream
 
         stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
         stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
         for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
             pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
-            comfy.model_management.cast_to_gathered(xfer_source, pin)
+            cast_maybe_lowvram_patch(xfer_source, pin, None)
             comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
     return offload_stream
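
Note (illustration, not part of the patch): the reworked
prefetch_prepared_value()/LowVramPatch.prepare() pair is the
"swiss-army-knife" of the subject line: one recursive walk that measures,
lays out, or copies depending on its arguments. destination=None only
accumulates VRAM-aligned sizes into counter; copy=False carves views out
of destination without moving bytes; copy=True also stages the bytes via
cast_to_gathered(); commit controls whether the views become the live
self.prepared_patches. A minimal sketch of the call pattern, assuming a
LowVramPatch instance `patch`; the names `cast_buffer`, `pin_view` and
`offload_stream` are hypothetical stand-ins for the buffers that ops.py
manages, not a fixed API:

    # 1) Measure: no destination, nothing copied; memory_required() sums
    #    the vram_aligned_size() of every tensor reachable from the patches.
    size = patch.memory_required()

    # 2) Layout: commit views into the cast buffer without moving bytes.
    patch.prepare(cast_buffer, None, copy=False, commit=True)

    # 3a) Copy, direct: stage bytes straight into the committed views on
    #     the offload stream (the non-pinned fallback path in ops.py).
    patch.prepare(cast_buffer, offload_stream, copy=True, commit=True)

    # 3b) Copy, via pinned host memory: fill the pinned view synchronously
    #     (commit=False keeps the layout committed in step 2), then move
    #     the pinned bytes onto the committed views asynchronously.
    patch.prepare(pin_view, None, copy=True, commit=False)
    comfy.model_management.cast_to_gathered(
        [pin_view], cast_buffer, non_blocking=True, stream=offload_stream)

Splitting layout (step 2) from copy (steps 3a/3b) is what lets ops.py queue
low-VRAM patches through the shared pinned staging buffer alongside plain
tensors, with cast_maybe_lowvram_patch() choosing the right copy primitive.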