Implement lora async offload

Implement async offload of loras.
This commit is contained in:
Rattus 2026-04-28 19:50:28 +10:00
parent 9242551e9b
commit d73b819eda
7 changed files with 150 additions and 92 deletions

View File

@ -16,7 +16,7 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
import comfy.ops
import comfy.model_prefetch
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
@ -908,11 +908,11 @@ class LTXAVModel(LTXVModel):
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
prefetch_queue = comfy.ops.make_prefetch_queue(list(self.transformer_blocks))
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device)
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, block)
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
@ -985,7 +985,7 @@ class LTXAVModel(LTXVModel):
a_prompt_timestep=a_prompt_timestep,
)
comfy.ops.prefetch_queue_pop(prefetch_queue, vx.device, None)
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
return [vx, ax]

View File

@ -17,6 +17,7 @@
"""
from __future__ import annotations
import comfy.memory_management
import comfy.utils
import comfy.model_management
import comfy.model_base
@ -467,3 +468,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
weight = old_weight
return weight
def prefetch_prepared_value(value, allocate_buffer, stream):
if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
return value

View File

@ -31,7 +31,6 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.model_vbar
import comfy_aimdo.vram_buffer
class VRAMState(Enum):
@ -1185,7 +1184,6 @@ STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
PREFETCH_QUEUES = []
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
@ -1228,31 +1226,9 @@ def get_aimdo_cast_buffer(offload_stream, device):
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def cleanup_prefetched_modules(comfy_modules):
for s in comfy_modules:
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
continue
if prefetch["signature"] is not None:
comfy_aimdo.model_vbar.vbar_unpin(s._v)
delattr(s, "_prefetch")
def cleanup_prefetch_queue(queue):
for entry in queue:
if entry is None or not isinstance(entry, tuple):
continue
_, prefetch_state = entry
comfy_modules = prefetch_state[1]
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
global PREFETCH_QUEUES
LARGEST_CASTED_WEIGHT = (None, 0)
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
@ -1261,10 +1237,6 @@ def reset_cast_buffers():
offload_stream.synchronize()
synchronize()
for queue in PREFETCH_QUEUES:
cleanup_prefetch_queue(queue)
PREFETCH_QUEUES = []
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
soft_empty_cache()

View File

@ -121,9 +121,20 @@ class LowVramPatch:
self.patches = patches
self.convert_func = convert_func # TODO: remove
self.set_func = set_func
self.prepared_patches = None
def prepare(self, allocate_buffer, stream):
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]
def clear_prepared(self):
self.prepared_patches = None
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

68
comfy/model_prefetch.py Normal file
View File

@ -0,0 +1,68 @@
import comfy_aimdo.model_vbar
import comfy.model_management
import comfy.ops
PREFETCH_QUEUES = []
def cleanup_prefetched_modules(comfy_modules):
for s in comfy_modules:
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
continue
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
if prefetch["signature"] is not None:
comfy_aimdo.model_vbar.vbar_unpin(s._v)
delattr(s, "_prefetch")
def cleanup_prefetch_queues():
global PREFETCH_QUEUES
for queue in PREFETCH_QUEUES:
for entry in queue:
if entry is None:
continue
_, prefetch_state = entry
comfy_modules = prefetch_state[1]
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
PREFETCH_QUEUES = []
def prefetch_queue_pop(queue, device, module):
if queue is None:
return
consumed = queue.pop(0)
if consumed is not None:
offload_stream, prefetch_state = consumed
offload_stream.wait_stream(comfy.model_management.current_stream(device))
_, comfy_modules = prefetch_state
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
prefetch = queue[0]
if prefetch is not None:
comfy_modules = []
for s in prefetch.modules():
if hasattr(s, "_v"):
comfy_modules.append(s)
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
if offload_stream is None:
queue[0] = None
return
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))
def make_prefetch_queue(queue, device):
if (comfy.model_management.NUM_STREAMS == 0
or comfy.model_management.is_device_cpu(device)
or not comfy.model_management.device_supports_non_blocking(device)):
return None
queue = [None] + queue + [None]
PREFETCH_QUEUES.append(queue)
return queue

View File

@ -92,6 +92,38 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
cast_buffer = None
cast_buffer_offset = 0
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
nonlocal cast_buffer
if offload_stream is None:
offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
return
current_size = 0 if cast_buffer is None else cast_buffer.size()
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = None
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
def get_cast_buffer(buffer_size):
nonlocal offload_stream
nonlocal cast_buffer
nonlocal cast_buffer_offset
if buffer_size == 0:
return None
if offload_stream is None:
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
cast_buffer_offset += buffer_size
return buffer
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@ -126,23 +158,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
if offload_stream is None:
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None and cast_buffer is None:
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
if len(comfy_modules) == 1:
if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size)
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
if xfer_dest is None:
if cast_buffer is not None:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size, cast_buffer_offset), device)
cast_buffer_offset += dest_size
else:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
@ -155,6 +173,13 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
prefetch["cast_geometry"] = cast_geometry
@ -164,47 +189,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
return offload_stream
def cast_prefetch_all(module, device):
prefetch_state = (module, None)
if (comfy.model_management.is_device_cpu(device)
or not comfy.model_management.device_supports_non_blocking(device)):
return (None, prefetch_state)
comfy_modules = []
for s in module.modules():
if hasattr(s, "_v"):
comfy_modules.append(s)
offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True)
comfy.model_management.sync_stream(device, offload_stream)
return (offload_stream, (module, comfy_modules))
def uncast_prefetch_all(prefetch_state):
_, comfy_modules = prefetch_state
if comfy_modules is not None:
comfy.model_management.cleanup_prefetched_modules(comfy_modules)
def prefetch_queue_pop(queue, device, module):
consumed = queue.pop(0)
if consumed is not None:
offload_stream, prefetch_state = consumed
if offload_stream is not None:
offload_stream.wait_stream(comfy.model_management.current_stream(device))
uncast_prefetch_all(prefetch_state)
prefetch = queue[0]
if prefetch is not None:
queue[0] = cast_prefetch_all(prefetch, device)
def make_prefetch_queue(queue):
queue = [None] + queue + [None]
comfy.model_management.PREFETCH_QUEUES.append(queue)
return queue
def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
del non_blocking
@ -237,6 +221,9 @@ def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requ
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
if x is None:
return None
orig = x
def to_dequant(tensor, dtype):
@ -266,8 +253,7 @@ def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requ
update_weight = prefetch["signature"] is not None
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
bias = None
if s.bias is not None:
if bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
return weight, bias
@ -321,6 +307,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None:
offload_device = device
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
delattr(s, "_prefetch")
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)

View File

@ -15,6 +15,7 @@ import torch
from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
import comfy.model_prefetch
import comfy_aimdo.model_vbar
from latent_preview import set_preview_method
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if args.verbose == "DEBUG":
comfy_aimdo.control.analyze()
comfy.model_management.reset_cast_buffers()
comfy.model_prefetch.cleanup_prefetch_queues()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks: