ComfyUI/comfy/model_prefetch.py
rattus e154da83b1
Threaded Loader performance fixes / improvements (+ Aimdo 0.4.6) (#14116)
* memory_management: Add direct read to GPU mode

Make destination optional (or make it optionally GPU) and use aimdo
to file_read direct to GPU.

* ops: Remove stream pin buffers and use aimdo reads

This consumed too much RAM, and it's better to just take the hit on
the CPU syncing back the stream on a short ring buffer. Aimdo
implements this, so just rip the stream pin buffer from comfy.

* model_management: allow active pin registration movement

It's better to just let the active model load past the pin limit as
pins and let the pins move around. This saves the HDD and SATA
people disk traffic while only costing a few GPU syncs.

* utils: use aimdo file handle

This opens on Windows with more favourable flags.

* mp: only count the model proper for loaded_ram and vram

Exclude live loras from the numbers to avoid the case where the reported
loaded memory exceeds the size of the model.

This caused me confusion in the Kijai visualizer when it looked fully
loaded but was hitting disk due to this accounting discrepancy.

* utils: add bit reverse utility

Useful for max-scattering something ordered.

* pinned_memory: Implement offload balancing

Use a max-scatter algorithm to prioritize pins of the same size such
that when doing a little bit of offloading it gets scattered, allowing
the prefetcher to more evenly swallow the offload.

* comfy-aimdo 0.4.7

Aimdo 0.4.7 implements VRAM buffer exhaustion prediction to avoid
early speculative load of weights that definitely won't fit once the
inference gets further in.

* model-prefetch: consolidate pin ensures on the sync point

This could happen mid prefetch block, cause a sync of the entire
block and lose overlap. Get ahead of the problem with a free down
at the natural compute stream sync point.

* mm: Put a 2GB min on the pin ceiling

This is reasonably bad if it starts causing swap pressure, more so than
during normal RAM-cache proceedings. Clamp it.

* add --fast-disk
2026-05-30 15:20:04 -04:00
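
The "bit reverse utility" and "offload balancing" entries above describe scattering an ordered sequence so that a small amount of offload lands on evenly spread positions. A minimal sketch of that idea follows. The names `bit_reverse` and `max_scatter_order` and their signatures are hypothetical illustrations, not the helpers this commit adds, and the real pin prioritization may differ.

```python
def bit_reverse(value: int, bits: int) -> int:
    """Reverse the lowest `bits` bits of `value` (hypothetical helper)."""
    result = 0
    for _ in range(bits):
        # Shift the lowest remaining bit of `value` onto the end of `result`.
        result = (result << 1) | (value & 1)
        value >>= 1
    return result


def max_scatter_order(count: int) -> list[int]:
    """Order indices 0..count-1 so that consecutive picks are spread far apart."""
    if count <= 0:
        return []
    # Enough bits to represent the largest index, and at least one bit.
    bits = max(1, (count - 1).bit_length())
    # Sorting by the bit-reversed index yields the scattered visiting order.
    return sorted(range(count), key=lambda i: bit_reverse(i, bits))
```

With eight equally sized items, `max_scatter_order(8)` returns `[0, 4, 2, 6, 1, 5, 3, 7]`, so offloading the first two entries touches positions 0 and 4 rather than two neighbours.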

78 lines
2.8 KiB
Python

import comfy_aimdo.model_vbar
import comfy.memory_management
import comfy.model_management
import comfy.ops

PREFETCH_QUEUES = []


def cleanup_prefetched_modules(comfy_modules):
    for s in comfy_modules:
        prefetch = getattr(s, "_prefetch", None)
        if prefetch is None:
            continue
        for param_key in ("weight", "bias"):
            lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
            if lowvram_fn is not None:
                lowvram_fn.clear_prepared()
        if prefetch["signature"] is not None:
            comfy_aimdo.model_vbar.vbar_unpin(s._v)
        delattr(s, "_prefetch")

def cleanup_prefetch_queues():
    global PREFETCH_QUEUES

    for queue in PREFETCH_QUEUES:
        for entry in queue:
            if entry is None or not isinstance(entry, tuple):
                continue
            _, prefetch_state = entry
            comfy_modules = prefetch_state[1]
            if comfy_modules is not None:
                cleanup_prefetched_modules(comfy_modules)
    PREFETCH_QUEUES = []

def prefetch_queue_pop(queue, device, module):
    if queue is None:
        return

    consumed = queue.pop(0)
    if consumed is not None:
        offload_stream, prefetch_state = consumed
        if offload_stream is not None:
            offload_stream.wait_stream(comfy.model_management.current_stream(device))
        _, comfy_modules = prefetch_state
        if comfy_modules is not None:
            cleanup_prefetched_modules(comfy_modules)

    prefetch = queue[0]
    if prefetch is not None:
        comfy_modules = []
        for s in prefetch.modules():
            if hasattr(s, "_v"):
                comfy_modules.append(s)

        registerable_size = 0
        for s in comfy_modules:
            registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias])
            for param_key in ("weight", "bias"):
                lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
                if lowvram_fn is not None:
                    registerable_size += lowvram_fn.memory_required()

        offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
        if not comfy.model_management.args.fast_disk:
            comfy.model_management.ensure_pin_registerable(registerable_size)
        comfy.model_management.sync_stream(device, offload_stream)
        queue[0] = (offload_stream, (prefetch, comfy_modules))

def make_prefetch_queue(queue, device, transformer_options):
    if (not transformer_options.get("prefetch_dynamic_vbars", False)
            or comfy.model_management.NUM_STREAMS == 0
            or comfy.model_management.is_device_cpu(device)
            or not comfy.model_management.device_supports_non_blocking(device)):
        return None

    queue = [None] + queue + [None]
    PREFETCH_QUEUES.append(queue)
    return queue
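
The `None` padding in `make_prefetch_queue` and the pop-then-peek pattern in `prefetch_queue_pop` can be hard to follow without running them. The sketch below is a toy model of that control flow only, under stated assumptions: plain strings stand in for modules, the `"in_flight"` tag stands in for the offload stream, and `make_queue`, `pop_and_prefetch`, and `run` are hypothetical names that do not exist in ComfyUI. No streams, pinning, or VRAM accounting are modelled.

```python
def make_queue(blocks):
    """Pad the block list with None sentinels on both ends."""
    # Leading None: nothing to release on the first pop.
    # Trailing None: nothing left to prefetch on the last pop.
    return [None] + list(blocks) + [None]


def pop_and_prefetch(queue, log):
    """Release the previously prefetched block, then start on the next one."""
    if queue is None:
        return
    consumed = queue.pop(0)
    if consumed is not None:
        _, block = consumed
        log.append(("release", block))
    upcoming = queue[0]
    if upcoming is not None:
        log.append(("prefetch", upcoming))
        # Replace the raw entry with a record of the in-flight prefetch,
        # so the next pop knows what to release.
        queue[0] = ("in_flight", upcoming)


def run(blocks):
    """Drive the queue once per block, plus one final pop to release the last."""
    queue = make_queue(blocks)
    log = []
    for _ in range(len(blocks) + 1):
        pop_and_prefetch(queue, log)
    return queue, log
```

For `run(["a", "b"])` the log is prefetch `a`, release `a`, prefetch `b`, release `b`, and only the trailing `None` sentinel remains in the queue. In this toy model a queue built from N blocks therefore expects exactly N + 1 pops; one more would index an empty list.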