mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
LowVRAMPatch: change to two-phase visit
This commit is contained in:
parent
663aaf8b96
commit
3b05f58788
@ -475,16 +475,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
||||
|
||||
return weight
|
||||
|
||||
def prefetch_prepared_value(value, allocate_buffer, stream):
|
||||
def prefetch_prepared_value(value, counter, destination, stream):
|
||||
if isinstance(value, torch.Tensor):
|
||||
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
size = comfy.memory_management.vram_aligned_size(value)
|
||||
offset = counter[0]
|
||||
counter[0] += size
|
||||
if destination is None:
|
||||
return value
|
||||
|
||||
dest = destination[offset:offset + size]
|
||||
if stream is not None:
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
|
||||
elif isinstance(value, weight_adapter.WeightAdapterBase):
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream))
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
|
||||
return tuple(prefetch_prepared_value(item, counter, destination, stream) for item in value)
|
||||
elif isinstance(value, list):
|
||||
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
|
||||
return [prefetch_prepared_value(item, counter, destination, stream) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
@ -124,9 +124,16 @@ class LowVramPatch:
|
||||
self.set_func = set_func
|
||||
self.prepared_patches = None
|
||||
|
||||
def prepare(self, allocate_buffer, stream):
|
||||
def memory_required(self):
|
||||
counter = [0]
|
||||
for patch in self.patches[self.key]:
|
||||
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None)
|
||||
return counter[0]
|
||||
|
||||
def prepare(self, destination, stream):
|
||||
counter = [0]
|
||||
self.prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream), patch[2], patch[3], patch[4])
|
||||
for patch in self.patches[self.key]
|
||||
]
|
||||
|
||||
|
||||
@ -203,7 +203,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
|
||||
lowvram_fn.prepare(get_cast_buffer(lowvram_fn.memory_required()), offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
|
||||
Loading…
Reference in New Issue
Block a user