LowVRAMPatch: change to two-phase visit
parent 663aaf8b96
commit 3b05f58788
@@ -475,16 +475,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
     return weight
 
-def prefetch_prepared_value(value, allocate_buffer, stream):
+def prefetch_prepared_value(value, counter, destination, stream):
     if isinstance(value, torch.Tensor):
-        dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
-        comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
+        size = comfy.memory_management.vram_aligned_size(value)
+        offset = counter[0]
+        counter[0] += size
+        if destination is None:
+            return value
+
+        dest = destination[offset:offset + size]
+        if stream is not None:
+            comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
         return comfy.memory_management.interpret_gathered_like([value], dest)[0]
     elif isinstance(value, weight_adapter.WeightAdapterBase):
-        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
+        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream))
     elif isinstance(value, tuple):
-        return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
+        return tuple(prefetch_prepared_value(item, counter, destination, stream) for item in value)
     elif isinstance(value, list):
-        return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
+        return [prefetch_prepared_value(item, counter, destination, stream) for item in value]
 
     return value
 
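For context, this hunk turns prefetch_prepared_value from an allocate-as-you-go visitor into one that can run twice over the same value tree: a measuring pass with destination=None that only advances a mutable byte counter, and a packing pass that slices a preallocated buffer at the offsets the counter records. Below is a minimal standalone sketch of that pattern, assuming a plain uint8 byte buffer; align() and the raw byte copy are illustrative stand-ins for vram_aligned_size, cast_to_gathered, and interpret_gathered_like, and the sketch ignores streams entirely.

import torch

def align(nbytes, boundary=512):
    # Stand-in for vram_aligned_size: round the payload up to an alignment boundary.
    return -(-nbytes // boundary) * boundary

def visit(value, counter, destination):
    if isinstance(value, torch.Tensor):
        nbytes = value.numel() * value.element_size()
        size = align(nbytes)
        offset = counter[0]
        counter[0] += size           # both passes advance the counter identically,
        if destination is None:      # so phase-2 offsets match the phase-1 total
            return value             # phase 1: measure only, copy nothing
        dest = destination[offset:offset + size]
        dest[:nbytes].copy_(value.contiguous().flatten().view(torch.uint8))
        return dest[:nbytes].view(value.dtype).view(value.shape)
    elif isinstance(value, (tuple, list)):
        return type(value)(visit(item, counter, destination) for item in value)
    return value

values = [torch.randn(4, 4), (torch.randn(8), "not-a-tensor")]
counter = [0]
visit(values, counter, None)                         # phase 1: size the whole tree
buffer = torch.empty(counter[0], dtype=torch.uint8)  # one allocation for everything
packed = visit(values, [0], buffer)                  # phase 2: pack into slices

The detail worth noting is the one-element list used as the counter: it is the cheapest mutable integer that recursion can share, and it is what the real code threads through the tuple, list, and WeightAdapterBase branches.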
@@ -124,9 +124,16 @@ class LowVramPatch:
         self.set_func = set_func
         self.prepared_patches = None
 
-    def prepare(self, allocate_buffer, stream):
+    def memory_required(self):
+        counter = [0]
+        for patch in self.patches[self.key]:
+            comfy.lora.prefetch_prepared_value(patch[1], counter, None, None)
+        return counter[0]
+
+    def prepare(self, destination, stream):
+        counter = [0]
         self.prepared_patches = [
-            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
+            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream), patch[2], patch[3], patch[4])
             for patch in self.patches[self.key]
         ]
 
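On the class side, the same visitor now backs two methods: memory_required() dry-runs it with destination=None and stream=None and reads off the counter, while prepare() replays it against the real destination, so both walks see identical values and produce identical offsets. A hypothetical caller-facing sketch of that contract follows; TwoPhase, its direct size arithmetic, and the plain uint8 buffer are stand-ins for LowVramPatch and the cast-buffer machinery, not the real API.

import torch

class TwoPhase:
    # Illustrative stand-in for LowVramPatch's measure-then-pack contract.
    def __init__(self, tensors):
        self.tensors = tensors
        self.prepared = None

    def memory_required(self):
        # Phase 1: a dry run over the same values prepare() will visit.
        return sum(t.numel() * t.element_size() for t in self.tensors)

    def prepare(self, destination, stream=None):
        # Phase 2: slice the caller-provided buffer instead of allocating.
        offset, out = 0, []
        for t in self.tensors:
            size = t.numel() * t.element_size()
            dest = destination[offset:offset + size]
            dest.copy_(t.contiguous().flatten().view(torch.uint8))
            out.append(dest.view(t.dtype).view(t.shape))
            offset += size
        self.prepared = out

patch = TwoPhase([torch.randn(3), torch.randn(5)])
buf = torch.empty(patch.memory_required(), dtype=torch.uint8)  # exact-size allocation
patch.prepare(buf)
assert torch.equal(patch.prepared[1], patch.tensors[1])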
@@ -203,7 +203,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
             lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
             if lowvram_fn is not None:
                 ensure_offload_stream(s, cast_buffer_offset, False)
-                lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
+                lowvram_fn.prepare(get_cast_buffer(lowvram_fn.memory_required()), offload_stream)
 
         prefetch["xfer_dest"] = xfer_dest
         prefetch["cast_dest"] = cast_dest
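The call-site change is the payoff of the two-phase visit: because lowvram_fn can now report its footprint up front, the lambda size: get_cast_buffer(size) allocator callback disappears and prepare() receives one buffer of exactly memory_required() bytes. A tiny sketch of the new call shape; get_buffer and Fn here are hypothetical stand-ins for get_cast_buffer and the object behind lowvram_fn.

import torch

def get_buffer(nbytes):
    # Hypothetical stand-in for get_cast_buffer.
    return torch.empty(nbytes, dtype=torch.uint8)

class Fn:
    def memory_required(self):
        return 4096                  # a fixed footprint keeps the demo short
    def prepare(self, destination, stream):
        assert destination.numel() >= self.memory_required()

fn = Fn()
# Old shape: fn.prepare(lambda size: get_buffer(size), stream), one allocation per value.
# New shape: size once, allocate once; prepare() only slices from here on.
fn.prepare(get_buffer(fn.memory_required()), None)

Since both passes visit the same values, the measured total equals the packing pass's final offset, so the single buffer cannot be undersized.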