lora: re-implement as in-place swiss-army-knife operation

This commit is contained in:
Rattus 2026-05-08 08:19:47 +10:00
parent 3b05f58788
commit 01eba77dba
4 changed files with 35 additions and 20 deletions

View File

@ -475,7 +475,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
return weight
def prefetch_prepared_value(value, counter, destination, stream):
def prefetch_prepared_value(value, counter, destination, stream, copy):
if isinstance(value, torch.Tensor):
size = comfy.memory_management.vram_aligned_size(value)
offset = counter[0]
@ -484,14 +484,14 @@ def prefetch_prepared_value(value, counter, destination, stream):
return value
dest = destination[offset:offset + size]
if stream is not None:
if copy:
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream))
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, counter, destination, stream) for item in value)
return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, counter, destination, stream) for item in value]
return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
return value

View File

@ -1201,7 +1201,7 @@ def get_pin_buffer(offload_stream):
def resize_pin_buffer(pin_buffer, size):
global TOTAL_PINNED_MEMORY
old_size = getattr(pin_buffer, "_comfy_stream_pin_size", 0)
old_size = pin_buffer.size
if size <= old_size:
return True
growth = size - old_size
@ -1210,8 +1210,7 @@ def resize_pin_buffer(pin_buffer, size):
pin_buffer.extend(size=size, reallocate=True)
except RuntimeError:
return False
pin_buffer._comfy_stream_pin_size = size
TOTAL_PINNED_MEMORY += growth
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
return True
def reset_cast_buffers():
@ -1231,7 +1230,7 @@ def reset_cast_buffers():
DIRTY_MMAPS.clear()
for pin_buffer in STREAM_PIN_BUFFERS.values():
TOTAL_PINNED_MEMORY -= getattr(pin_buffer, "_comfy_stream_pin_size", 0)
TOTAL_PINNED_MEMORY -= pin_buffer.size
if TOTAL_PINNED_MEMORY < 0:
TOTAL_PINNED_MEMORY = 0

View File

@ -117,6 +117,8 @@ def string_to_seed(data):
return comfy.utils.string_to_seed(data)
class LowVramPatch:
is_lowvram_patch = True
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
@ -127,15 +129,18 @@ class LowVramPatch:
def memory_required(self):
counter = [0]
for patch in self.patches[self.key]:
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None)
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
return counter[0]
def prepare(self, destination, stream):
def prepare(self, destination, stream, copy=True, commit=True):
counter = [0]
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream), patch[2], patch[3], patch[4])
prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]
if commit:
self.prepared_patches = prepared_patches
return prepared_patches
def clear_prepared(self):
self.prepared_patches = None

View File

@ -203,7 +203,14 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(get_cast_buffer(lowvram_fn.memory_required()), offload_stream)
lowvram_size = lowvram_fn.memory_required()
lowvram_dest = get_cast_buffer(lowvram_size)
lowvram_fn.prepare(lowvram_dest, None, copy=False, commit=True)
pin_offset = get_stream_pin_buffer_offset(lowvram_size)
if pin_offset is not None:
stream_pin_queue.append((lowvram_fn, pin_offset, lowvram_size, lowvram_dest))
else:
lowvram_fn.prepare(lowvram_dest, offload_stream, copy=True, commit=True)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
@ -211,19 +218,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
if getattr(xfer_source, "is_lowvram_patch", False):
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
else:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
if stream_pin_offset > 0:
stream_pin_hostbuf_size = getattr(stream_pin_hostbuf, "_comfy_stream_pin_size", stream_pin_hostbuf.size)
if stream_pin_hostbuf_size < stream_pin_offset:
stream_pin_hostbuf_size = stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_hostbuf_size):
if stream_pin_hostbuf.size < stream_pin_offset:
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
for xfer_source, _, _, xfer_dest in stream_pin_queue:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
return offload_stream
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
comfy.model_management.cast_to_gathered(xfer_source, pin)
cast_maybe_lowvram_patch(xfer_source, pin, None)
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
return offload_stream