From 3b05f58788dce5a4f607a72a7b68e9cc7fc02d07 Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 7 May 2026 23:50:37 +1000 Subject: [PATCH] LowVRAMPatch: change to two-phase visit --- comfy/lora.py | 19 +++++++++++++------ comfy/model_patcher.py | 11 +++++++++-- comfy/ops.py | 2 +- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/comfy/lora.py b/comfy/lora.py index db8f16bcb..f7c7c21a5 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -475,16 +475,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori return weight -def prefetch_prepared_value(value, allocate_buffer, stream): +def prefetch_prepared_value(value, counter, destination, stream): if isinstance(value, torch.Tensor): - dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) - comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + size = comfy.memory_management.vram_aligned_size(value) + offset = counter[0] + counter[0] += size + if destination is None: + return value + + dest = destination[offset:offset + size] + if stream is not None: + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) return comfy.memory_management.interpret_gathered_like([value], dest)[0] elif isinstance(value, weight_adapter.WeightAdapterBase): - return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream)) elif isinstance(value, tuple): - return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + return tuple(prefetch_prepared_value(item, counter, destination, stream) for item in value) elif isinstance(value, list): - return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + return [prefetch_prepared_value(item, counter, destination, stream) for item in value] return value diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0b590138d..995b1b36b 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -124,9 +124,16 @@ class LowVramPatch: self.set_func = set_func self.prepared_patches = None - def prepare(self, allocate_buffer, stream): + def memory_required(self): + counter = [0] + for patch in self.patches[self.key]: + comfy.lora.prefetch_prepared_value(patch[1], counter, None, None) + return counter[0] + + def prepare(self, destination, stream): + counter = [0] self.prepared_patches = [ - (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream), patch[2], patch[3], patch[4]) for patch in self.patches[self.key] ] diff --git a/comfy/ops.py b/comfy/ops.py index e060d9ba9..cb0308acd 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -203,7 +203,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin lowvram_fn = getattr(s, param_key + "_lowvram_function", None) if lowvram_fn is not None: ensure_offload_stream(s, cast_buffer_offset, False) - lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + lowvram_fn.prepare(get_cast_buffer(lowvram_fn.memory_required()), offload_stream) prefetch["xfer_dest"] = xfer_dest prefetch["cast_dest"] = cast_dest