"""Helpers that manage per-run prefetch queues of modules carrying a ``_v`` attribute."""

import comfy_aimdo.model_vbar
import comfy.model_management
import comfy.ops

# Every queue handed out by make_prefetch_queue() is recorded here so that
# cleanup_prefetch_queues() can release whatever is still outstanding.
PREFETCH_QUEUES = []


def cleanup_prefetched_modules(comfy_modules):
    """Release the prefetch state attached to each module in *comfy_modules*.

    Modules without a ``_prefetch`` attribute (or with it set to ``None``)
    are left untouched. For the rest, the ``weight``/``bias`` low-VRAM
    functions (when present) get ``clear_prepared()`` called, the module's
    ``_v`` is unpinned if the prefetch record has a non-``None``
    ``"signature"``, and finally the ``_prefetch`` attribute is removed.
    """
    for mod in comfy_modules:
        state = getattr(mod, "_prefetch", None)
        if state is None:
            continue

        for fn_attr in ("weight_lowvram_function", "bias_lowvram_function"):
            prepared_fn = getattr(mod, fn_attr, None)
            if prepared_fn is not None:
                prepared_fn.clear_prepared()

        # NOTE(review): a non-None signature presumably means the vbar was
        # pinned during prefetch -- confirm against cast_modules_with_vbar.
        if state["signature"] is not None:
            comfy_aimdo.model_vbar.vbar_unpin(mod._v)

        del mod._prefetch


def cleanup_prefetch_queues():
    """Clean up every in-flight entry of every registered queue, then reset the registry.

    Only tuple entries -- the ``(offload_stream, (prefetch, comfy_modules))``
    records written by prefetch_queue_pop() -- are processed; ``None``
    sentinels and not-yet-prefetched items are skipped.
    """
    global PREFETCH_QUEUES

    for pending in PREFETCH_QUEUES:
        for item in pending:
            # None is not a tuple, so this single check also skips sentinels.
            if not isinstance(item, tuple):
                continue
            _, state = item
            modules = state[1]
            if modules is not None:
                cleanup_prefetched_modules(modules)

    # Rebind (rather than clear in place) exactly as before.
    PREFETCH_QUEUES = []


def prefetch_queue_pop(queue, device, module):
    """Retire the head of *queue* and start prefetching the new head.

    Does nothing when *queue* is ``None``. Otherwise the first entry is
    popped; if it is a prefetch record, its offload stream is made to wait
    on the current stream of *device* and its modules are cleaned up. The
    new head, if not ``None``, is then cast via
    ``comfy.ops.cast_modules_with_vbar`` and replaced in the queue by a
    ``(offload_stream, (prefetch, comfy_modules))`` record.

    *module* is accepted but not used by this function.
    """
    if queue is None:
        return

    finished = queue.pop(0)
    if finished is not None:
        done_stream, done_state = finished
        done_stream.wait_stream(comfy.model_management.current_stream(device))
        _, done_modules = done_state
        if done_modules is not None:
            cleanup_prefetched_modules(done_modules)

    upcoming = queue[0]
    if upcoming is None:
        return

    # Only submodules that expose a ``_v`` attribute take part in prefetching.
    vbar_modules = [sub for sub in upcoming.modules() if hasattr(sub, "_v")]
    stream = comfy.ops.cast_modules_with_vbar(vbar_modules, None, device, None, True)
    comfy.model_management.sync_stream(device, stream)
    queue[0] = (stream, (upcoming, vbar_modules))


def make_prefetch_queue(queue, device, transformer_options):
    """Build and register a sentinel-padded prefetch queue, or return ``None``.

    ``None`` is returned when ``transformer_options["prefetch_dynamic_vbars"]``
    is missing/falsy, when ``comfy.model_management.NUM_STREAMS`` is 0, when
    *device* is a CPU device, or when *device* does not support non-blocking
    transfers. Otherwise a new list ``[None] + queue + [None]`` is appended to
    ``PREFETCH_QUEUES`` and returned.
    """
    if not transformer_options.get("prefetch_dynamic_vbars", False):
        return None

    mm = comfy.model_management
    if mm.NUM_STREAMS == 0:
        return None
    if mm.is_device_cpu(device):
        return None
    if not mm.device_supports_non_blocking(device):
        return None

    padded = [None] + queue + [None]
    PREFETCH_QUEUES.append(padded)
    return padded