# ---------------------------------------------------------------------------
# comfy/model_management.py -- prefetch bookkeeping
#
# File-level requirement: `import comfy_aimdo.model_vbar` must be present in
# the module's import block (the patch adds it directly before the existing
# `import comfy_aimdo.vram_buffer`).
# ---------------------------------------------------------------------------

# Every queue built by make_prefetch_queue() is registered here so that
# reset_cast_buffers() can release whatever is still prefetched.
PREFETCH_QUEUES = []


def cleanup_prefetched_modules(comfy_modules):
    """Drop the prefetch marker from each module in *comfy_modules*.

    Modules without a ``_prefetch`` attribute are left untouched.  For the
    others, ``vbar_unpin`` is called on the module's ``_v`` when the marker's
    ``"signature"`` entry is not None, and the ``_prefetch`` attribute is then
    removed.

    Args:
        comfy_modules: iterable of module objects to clean up.
    """
    for s in comfy_modules:
        prefetch = getattr(s, "_prefetch", None)
        if prefetch is None:
            continue
        if prefetch["signature"] is not None:
            comfy_aimdo.model_vbar.vbar_unpin(s._v)
        del s._prefetch


def cleanup_prefetch_queue(queue):
    """Release every already-prefetched entry of *queue*.

    Only tuple entries are processed; they are expected to have the shape
    ``(offload_stream, (module, comfy_modules))`` that cast_prefetch_all()
    returns.  Anything else (``None`` padding, entries that have not been
    replaced by a prefetch result yet) is skipped.

    Args:
        queue: a list as produced by make_prefetch_queue().
    """
    for entry in queue:
        # None is not a tuple, so this one check also skips the padding slots.
        if not isinstance(entry, tuple):
            continue
        _, prefetch_state = entry
        comfy_modules = prefetch_state[1]
        if comfy_modules is not None:
            cleanup_prefetched_modules(comfy_modules)


# NOTE: the same patch also extends reset_cast_buffers() (only partly visible
# here, so it is not reproduced as code).  It adds `global PREFETCH_QUEUES`
# next to the other global declarations and, right after the stream
# synchronisation and before STREAM_CAST_BUFFERS.clear(), runs:
#
#     for queue in PREFETCH_QUEUES:
#         cleanup_prefetch_queue(queue)
#     PREFETCH_QUEUES = []


# ---------------------------------------------------------------------------
# comfy/ops.py -- prefetch queue helpers
# (placed after cast_modules_with_vbar and before phase_2)
# ---------------------------------------------------------------------------

def cast_prefetch_all(module, device):
    """Prefetch all submodules of *module* that carry a ``_v`` attribute.

    Args:
        module: object exposing ``modules()``; every yielded submodule with a
            ``_v`` attribute is handed to cast_modules_with_vbar().
        device: target device.

    Returns:
        ``(offload_stream, (module, comfy_modules))``.  For a CPU device, or a
        device without non-blocking support, nothing is cast and the result
        is ``(None, (module, None))``.
    """
    if (comfy.model_management.is_device_cpu(device)
            or not comfy.model_management.device_supports_non_blocking(device)):
        return (None, (module, None))

    comfy_modules = [s for s in module.modules() if hasattr(s, "_v")]

    # Positional arguments: dtype=None, device, bias_dtype=None, and True for
    # what appears to be the non_blocking flag (the signature is truncated in
    # the patch context -- confirm against cast_modules_with_vbar).
    offload_stream = cast_modules_with_vbar(comfy_modules, None, device, None, True)
    return (offload_stream, (module, comfy_modules))


def uncast_prefetch_all(prefetch_state):
    """Undo cast_prefetch_all() for one ``(module, comfy_modules)`` state.

    Does nothing when ``comfy_modules`` is None, i.e. when no prefetch was
    started for that module.
    """
    _, comfy_modules = prefetch_state
    if comfy_modules is not None:
        comfy.model_management.cleanup_prefetched_modules(comfy_modules)


def prefetch_queue_pop(queue, device, module):
    """Advance *queue* by one slot.

    Steps, in order:
      1. Pop the head (the entry that was active until now).  If it holds a
         prefetch result, its offload stream is made to wait on the current
         stream of *device* and its modules are released.
      2. The new head becomes active.  It must belong to *module*; its
         offload stream, if any, is passed to ``sync_stream``.
      3. The entry after the head, if not None, is replaced by the result of
         cast_prefetch_all() so it is ready for the next call.

    Args:
        queue: list created by make_prefetch_queue(); mutated in place.
        device: device the modules run on.
        module: the module about to run; checked against the active entry.

    Raises:
        AssertionError: if the active entry was prefetched for a different
            module than *module*.
        IndexError: if fewer than three entries are left in *queue*.
    """
    consumed = queue.pop(0)
    if consumed is not None:
        offload_stream, prefetch_state = consumed
        if offload_stream is not None:
            offload_stream.wait_stream(comfy.model_management.current_stream(device))
        uncast_prefetch_all(prefetch_state)

    active = queue[0]
    if active is not None:
        offload_stream, prefetch_state = active
        # Raised explicitly (instead of a bare `assert`) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if prefetch_state[0] is not module:
            raise AssertionError(
                "prefetch queue is out of step with the module being executed"
            )
        if offload_stream is not None:
            comfy.model_management.sync_stream(device, offload_stream)

    prefetch = queue[1]
    if prefetch is not None:
        queue[1] = cast_prefetch_all(prefetch, device)


def make_prefetch_queue(queue):
    """Build and register a prefetch queue from an iterable of modules.

    Two ``None`` slots are added at each end so that prefetch_queue_pop() can
    always read ``queue[0]`` and ``queue[1]`` at the start and at the end of
    the sequence.  The new list is appended to
    ``comfy.model_management.PREFETCH_QUEUES``.

    Args:
        queue: iterable of modules in execution order (any iterable is
            accepted, not only a list).

    Returns:
        The padded list; the argument itself is not modified.
    """
    queue = [None, None, *queue, None, None]
    comfy.model_management.PREFETCH_QUEUES.append(queue)
    return queue