LowVramPatch: change to two-phase visit

This commit is contained in:
Rattus 2026-05-07 23:50:37 +10:00
parent 663aaf8b96
commit 3b05f58788
3 changed files with 23 additions and 9 deletions

View File

@ -475,16 +475,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
return weight
def prefetch_prepared_value(value, allocate_buffer, stream):
def prefetch_prepared_value(value, counter, destination, stream):
if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
size = comfy.memory_management.vram_aligned_size(value)
offset = counter[0]
counter[0] += size
if destination is None:
return value
dest = destination[offset:offset + size]
if stream is not None:
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
return tuple(prefetch_prepared_value(item, counter, destination, stream) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
return [prefetch_prepared_value(item, counter, destination, stream) for item in value]
return value

View File

@ -124,9 +124,16 @@ class LowVramPatch:
self.set_func = set_func
self.prepared_patches = None
def prepare(self, allocate_buffer, stream):
def memory_required(self):
counter = [0]
for patch in self.patches[self.key]:
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None)
return counter[0]
def prepare(self, destination, stream):
counter = [0]
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]

View File

@ -203,7 +203,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
lowvram_fn.prepare(get_cast_buffer(lowvram_fn.memory_required()), offload_stream)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest