Mirror of https://github.com/comfyanonymous/ComfyUI.git

commit 01eba77dba (parent 3b05f58788)
lora: re-implement as inplace swiss-army-knife operation
@@ -475,7 +475,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
     return weight


-def prefetch_prepared_value(value, counter, destination, stream):
+def prefetch_prepared_value(value, counter, destination, stream, copy):
     if isinstance(value, torch.Tensor):
         size = comfy.memory_management.vram_aligned_size(value)
         offset = counter[0]
@@ -484,14 +484,14 @@ def prefetch_prepared_value(value, counter, destination, stream):
             return value

         dest = destination[offset:offset + size]
-        if stream is not None:
+        if copy:
             comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
         return comfy.memory_management.interpret_gathered_like([value], dest)[0]
     elif isinstance(value, weight_adapter.WeightAdapterBase):
-        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream))
+        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
     elif isinstance(value, tuple):
-        return tuple(prefetch_prepared_value(item, counter, destination, stream) for item in value)
+        return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
     elif isinstance(value, list):
-        return [prefetch_prepared_value(item, counter, destination, stream) for item in value]
+        return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]

     return value
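The extra `copy` argument is what makes prefetch_prepared_value the "swiss-army-knife" of the commit title: with destination=None it only measures aligned sizes, with copy=False it lays typed views over an already-gathered buffer without moving data, and with copy=True it also performs the copy. Below is a minimal standalone sketch of that three-mode pattern; the alignment constant and the byte-level view/copy logic are simplified stand-ins for vram_aligned_size, cast_to_gathered and interpret_gathered_like, not the real helpers.

import torch

ALIGN = 256  # assumed alignment; the real code uses comfy.memory_management.vram_aligned_size()

def aligned_size(t):
    nbytes = t.numel() * t.element_size()
    return -(-nbytes // ALIGN) * ALIGN  # round up to the alignment

def walk(value, counter, destination, copy):
    if isinstance(value, torch.Tensor):
        size = aligned_size(value)
        offset = counter[0]
        counter[0] = offset + size
        if destination is None:
            return value  # measuring pass: only the counter advances
        nbytes = value.numel() * value.element_size()
        dest = destination[offset:offset + nbytes]
        if copy:
            dest.copy_(value.reshape(-1).view(torch.uint8))
        # Reinterpret the raw bytes as a tensor shaped like the original.
        return dest.view(value.dtype).reshape(value.shape)
    if isinstance(value, tuple):
        return tuple(walk(v, counter, destination, copy) for v in value)
    if isinstance(value, list):
        return [walk(v, counter, destination, copy) for v in value]
    return value  # scalars and everything else pass through untouched

value = ([torch.randn(4, 4), torch.randn(8)], 0.5)
counter = [0]
walk(value, counter, None, False)          # mode 1: measure total size
buf = torch.empty(counter[0], dtype=torch.uint8)
counter = [0]
views = walk(value, counter, buf, False)   # mode 2: map layout, no copy
counter = [0]
views = walk(value, counter, buf, True)    # mode 3: copy and map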
@@ -1201,7 +1201,7 @@ def get_pin_buffer(offload_stream):

 def resize_pin_buffer(pin_buffer, size):
     global TOTAL_PINNED_MEMORY
-    old_size = getattr(pin_buffer, "_comfy_stream_pin_size", 0)
+    old_size = pin_buffer.size
     if size <= old_size:
         return True
     growth = size - old_size
@@ -1210,8 +1210,7 @@ def resize_pin_buffer(pin_buffer, size):
         pin_buffer.extend(size=size, reallocate=True)
     except RuntimeError:
         return False
-    pin_buffer._comfy_stream_pin_size = size
-    TOTAL_PINNED_MEMORY += growth
+    TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
     return True


 def reset_cast_buffers():
@@ -1231,7 +1230,7 @@ def reset_cast_buffers():
     DIRTY_MMAPS.clear()

     for pin_buffer in STREAM_PIN_BUFFERS.values():
-        TOTAL_PINNED_MEMORY -= getattr(pin_buffer, "_comfy_stream_pin_size", 0)
+        TOTAL_PINNED_MEMORY -= pin_buffer.size
     if TOTAL_PINNED_MEMORY < 0:
         TOTAL_PINNED_MEMORY = 0
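Both pin-buffer hunks replace the `_comfy_stream_pin_size` shadow attribute with the hostbuf's own `size`, so pinned-memory accounting is always derived from the allocation itself. A plausible reason (my inference, not stated in the commit) is that extend() can allocate more than requested, letting a cached attribute drift from the bytes actually pinned; measuring the delta after extend() keeps TOTAL_PINNED_MEMORY honest. A toy reproduction with a mock page-rounding buffer:

TOTAL_PINNED_MEMORY = 0

class MockPinBuffer:
    # Stand-in for the real hostbuf: extend() rounds allocations up to
    # 4 KiB pages, so requested size and actual size can differ.
    def __init__(self):
        self.size = 0
    def extend(self, size, reallocate=True):
        self.size = -(-size // 4096) * 4096

def resize_pin_buffer(pin_buffer, size):
    global TOTAL_PINNED_MEMORY
    old_size = pin_buffer.size            # ground truth, not a shadow attribute
    if size <= old_size:
        return True
    try:
        pin_buffer.extend(size=size, reallocate=True)
    except RuntimeError:
        return False
    TOTAL_PINNED_MEMORY += pin_buffer.size - old_size   # bytes actually pinned
    return True

buf = MockPinBuffer()
resize_pin_buffer(buf, 1000)
assert buf.size == 4096 and TOTAL_PINNED_MEMORY == 4096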
@@ -117,6 +117,8 @@ def string_to_seed(data):
     return comfy.utils.string_to_seed(data)

 class LowVramPatch:
+    is_lowvram_patch = True
+
     def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
@@ -127,15 +129,18 @@ class LowVramPatch:
     def memory_required(self):
         counter = [0]
         for patch in self.patches[self.key]:
-            comfy.lora.prefetch_prepared_value(patch[1], counter, None, None)
+            comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
         return counter[0]

-    def prepare(self, destination, stream):
+    def prepare(self, destination, stream, copy=True, commit=True):
         counter = [0]
-        self.prepared_patches = [
-            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream), patch[2], patch[3], patch[4])
+        prepared_patches = [
+            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
             for patch in self.patches[self.key]
         ]
+        if commit:
+            self.prepared_patches = prepared_patches
+        return prepared_patches

     def clear_prepared(self):
         self.prepared_patches = None
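prepare() now separates two concerns that used to be fused: whether bytes are copied into the destination (copy) and whether the resulting views become the committed self.prepared_patches (commit). The sketch below is a reduced standalone mock of that contract; the walker is a trivial stand-in for comfy.lora.prefetch_prepared_value, and the patches are two-element tuples instead of the real five-element ones.

import torch

def walk(value, counter, destination, stream, copy):
    # Trivial stand-in: one tensor per patch value, byte-exact layout.
    size = value.numel() * value.element_size()
    offset = counter[0]
    counter[0] = offset + size
    if destination is None:
        return value
    dest = destination[offset:offset + size]
    if copy:
        dest.copy_(value.reshape(-1).view(torch.uint8))
    return dest.view(value.dtype).reshape(value.shape)

class MockLowVramPatch:
    is_lowvram_patch = True

    def __init__(self, key, patches):
        self.key = key
        self.patches = patches
        self.prepared_patches = None

    def memory_required(self):
        counter = [0]
        for patch in self.patches[self.key]:
            walk(patch[1], counter, None, None, False)
        return counter[0]

    def prepare(self, destination, stream, copy=True, commit=True):
        counter = [0]
        prepared = [(patch[0], walk(patch[1], counter, destination, stream, copy))
                    for patch in self.patches[self.key]]
        if commit:
            self.prepared_patches = prepared
        return prepared

p = MockLowVramPatch("w", {"w": [(1.0, torch.randn(4, 4))]})
dest = torch.empty(p.memory_required(), dtype=torch.uint8)
p.prepare(dest, None, copy=False, commit=True)   # commit the layout, no data yet
p.prepare(dest, None, copy=True, commit=False)   # fill bytes, keep committed views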
comfy/ops.py (25 changed lines)
@@ -203,7 +203,14 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
             lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
             if lowvram_fn is not None:
                 ensure_offload_stream(s, cast_buffer_offset, False)
-                lowvram_fn.prepare(get_cast_buffer(lowvram_fn.memory_required()), offload_stream)
+                lowvram_size = lowvram_fn.memory_required()
+                lowvram_dest = get_cast_buffer(lowvram_size)
+                lowvram_fn.prepare(lowvram_dest, None, copy=False, commit=True)
+                pin_offset = get_stream_pin_buffer_offset(lowvram_size)
+                if pin_offset is not None:
+                    stream_pin_queue.append((lowvram_fn, pin_offset, lowvram_size, lowvram_dest))
+                else:
+                    lowvram_fn.prepare(lowvram_dest, offload_stream, copy=True, commit=True)

         prefetch["xfer_dest"] = xfer_dest
         prefetch["cast_dest"] = cast_dest
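The rewritten block splits preparation from transport: it commits the cast-buffer layout up front (copy=False), then either queues the patch for the pinned staging pass handled below or, when get_stream_pin_buffer_offset returns no pin space, copies directly on the offload stream. A condensed, self-contained restatement of that decision; the callables here are trivial stand-ins assumed to behave as the hunk implies.

def schedule_lowvram(lowvram_fn, get_cast_buffer, get_stream_pin_buffer_offset,
                     stream_pin_queue, offload_stream):
    lowvram_size = lowvram_fn.memory_required()
    lowvram_dest = get_cast_buffer(lowvram_size)
    # Commit destination views now; bytes arrive later by one of two routes.
    lowvram_fn.prepare(lowvram_dest, None, copy=False, commit=True)
    pin_offset = get_stream_pin_buffer_offset(lowvram_size)
    if pin_offset is not None:
        # Route 1: stage through pinned host memory in the batched loop below.
        stream_pin_queue.append((lowvram_fn, pin_offset, lowvram_size, lowvram_dest))
    else:
        # Route 2: no pin space reserved; copy straight over the offload stream.
        lowvram_fn.prepare(lowvram_dest, offload_stream, copy=True, commit=True)

class _Fn:  # trivial stand-in for a LowVramPatch
    def memory_required(self): return 64
    def prepare(self, dest, stream, copy=True, commit=True): pass

queue = []
schedule_lowvram(_Fn(), lambda n: bytearray(n), lambda n: 0, queue, None)
assert len(queue) == 1  # pin space available, so the patch was queued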
@@ -211,19 +218,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
         prefetch["needs_cast"] = needs_cast
         s._prefetch = prefetch

+    def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
+        if getattr(xfer_source, "is_lowvram_patch", False):
+            xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
+        else:
+            comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
+
     if stream_pin_offset > 0:
-        stream_pin_hostbuf_size = getattr(stream_pin_hostbuf, "_comfy_stream_pin_size", stream_pin_hostbuf.size)
-        if stream_pin_hostbuf_size < stream_pin_offset:
-            stream_pin_hostbuf_size = stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM
-            if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_hostbuf_size):
+        if stream_pin_hostbuf.size < stream_pin_offset:
+            if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
                 for xfer_source, _, _, xfer_dest in stream_pin_queue:
-                    comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+                    cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
                 return offload_stream
         stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
         stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
         for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
            pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
-            comfy.model_management.cast_to_gathered(xfer_source, pin)
+            cast_maybe_lowvram_patch(xfer_source, pin, None)
             comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)

     return offload_stream
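cast_maybe_lowvram_patch makes both staging loops polymorphic over their queue entries: a source can now be a LowVramPatch (detected via the new is_lowvram_patch class attribute) as well as a list of gathered tensors. A patch fills the staging area through prepare(copy=True, commit=False), leaving the views committed earlier still bound to the cast buffer; the second hop from pin to device stays a plain gathered copy on the offload stream in both cases. A CPU-only sketch of the dispatch, with cast_to_gathered reduced to a synchronous byte concat:

import torch

def cast_to_gathered(sources, dest, non_blocking=False, stream=None):
    # Stand-in: the real helper does a (possibly async) gathered device copy.
    torch.cat([s.reshape(-1).view(torch.uint8) for s in sources], out=dest)

def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
    if getattr(xfer_source, "is_lowvram_patch", False):
        # Fill the staging bytes without re-committing the prepared views.
        xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
    else:
        cast_to_gathered(xfer_source, xfer_dest, stream=stream)

# A plain tensor list falls through to the gathered copy:
dst = torch.empty(16, dtype=torch.uint8)
cast_maybe_lowvram_patch([torch.randn(2, 2)], dst, None)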