mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-25 16:37:23 +08:00
ops: implement block prefetching API
Allow a model to construct a prefetch list and operate on it to increase asynchronous offload.
This commit is contained in:
parent
0e93c88c67
commit
74261f12f2
@ -31,6 +31,7 @@ from contextlib import nullcontext
|
|||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
|
import comfy_aimdo.model_vbar
|
||||||
import comfy_aimdo.vram_buffer
|
import comfy_aimdo.vram_buffer
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
@ -1184,6 +1185,7 @@ STREAM_CAST_BUFFERS = {}
|
|||||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||||
|
# Registry of live prefetch queues. make_prefetch_queue() in comfy/ops.py
# appends every queue it builds here; reset_cast_buffers() runs
# cleanup_prefetch_queue() on each one and then rebinds this to an empty list.
PREFETCH_QUEUES = []
|
||||||
|
|
||||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||||
|
|
||||||
@ -1227,9 +1229,30 @@ def get_aimdo_cast_buffer(offload_stream, device):
|
|||||||
|
|
||||||
return cast_buffer
|
return cast_buffer
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_prefetched_modules(comfy_modules):
    """Remove the ``_prefetch`` record from every module that carries one.

    Modules whose ``_prefetch`` attribute is missing or ``None`` are left
    untouched. For the others, when the record's ``"signature"`` entry is
    not ``None`` the module's ``_v`` is handed to
    ``comfy_aimdo.model_vbar.vbar_unpin`` before the attribute is deleted.

    Args:
        comfy_modules: iterable of module objects to clean up.
    """
    for mod in comfy_modules:
        record = getattr(mod, "_prefetch", None)
        if record is not None:
            # NOTE(review): a non-None signature presumably means the vbar was
            # pinned when the prefetch was issued -- confirm against the
            # code that sets `_prefetch`.
            if record["signature"] is not None:
                comfy_aimdo.model_vbar.vbar_unpin(mod._v)
            del mod._prefetch
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_prefetch_queue(queue):
    """Clean up every prefetched entry still sitting in ``queue``.

    Only tuple entries are processed; each is unpacked as a pair whose
    second item is a prefetch state, and the state's second item is the
    module list. ``None`` placeholders and any other non-tuple entries are
    skipped. A module list of ``None`` means there is nothing to clean up
    for that entry.

    Args:
        queue: iterable of queue entries.
    """
    for slot in queue:
        # `None` is not a tuple, so this single test also skips placeholders.
        if not isinstance(slot, tuple):
            continue
        _stream, state = slot
        modules = state[1]
        if modules is None:
            continue
        cleanup_prefetched_modules(modules)
|
||||||
|
|
||||||
def reset_cast_buffers():
|
def reset_cast_buffers():
|
||||||
global LARGEST_CASTED_WEIGHT
|
global LARGEST_CASTED_WEIGHT
|
||||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||||
|
global PREFETCH_QUEUES
|
||||||
|
|
||||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||||
@ -1237,6 +1260,11 @@ def reset_cast_buffers():
|
|||||||
if offload_stream is not None:
|
if offload_stream is not None:
|
||||||
offload_stream.synchronize()
|
offload_stream.synchronize()
|
||||||
synchronize()
|
synchronize()
|
||||||
|
|
||||||
|
for queue in PREFETCH_QUEUES:
|
||||||
|
cleanup_prefetch_queue(queue)
|
||||||
|
PREFETCH_QUEUES = []
|
||||||
|
|
||||||
STREAM_CAST_BUFFERS.clear()
|
STREAM_CAST_BUFFERS.clear()
|
||||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|||||||
45
comfy/ops.py
45
comfy/ops.py
@ -164,6 +164,51 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
|||||||
return offload_stream
|
return offload_stream
|
||||||
|
|
||||||
|
|
||||||
|
def cast_prefetch_all(module, device):
    """Start a prefetch cast for every sub-module of ``module`` that has ``_v``.

    Args:
        module: object exposing ``modules()``; its sub-modules are scanned.
        device: target device for the cast.

    Returns:
        ``(offload_stream, (module, comfy_modules))``. When ``device`` is a
        CPU device or does not support non-blocking transfers, nothing is
        cast and ``(None, (module, None))`` is returned instead.
    """
    can_prefetch = (
        not comfy.model_management.is_device_cpu(device)
        and comfy.model_management.device_supports_non_blocking(device)
    )
    if not can_prefetch:
        return (None, (module, None))

    # Only sub-modules carrying a `_v` attribute take part in the cast.
    vbar_modules = [child for child in module.modules() if hasattr(child, "_v")]

    stream = cast_modules_with_vbar(vbar_modules, None, device, None, True)
    return (stream, (module, vbar_modules))
|
||||||
|
|
||||||
|
|
||||||
|
def uncast_prefetch_all(prefetch_state):
    """Undo a prefetch described by ``prefetch_state``.

    Args:
        prefetch_state: two-item sequence whose second item is the list of
            prefetched modules, or ``None`` when nothing was prefetched.
    """
    _owner, modules = prefetch_state
    if modules is None:
        return
    comfy.model_management.cleanup_prefetched_modules(modules)
|
||||||
|
|
||||||
|
|
||||||
|
def prefetch_queue_pop(queue, device, module):
    """Advance ``queue`` by one slot.

    Three things happen, in this order:

    1. The head entry is removed. If it is not ``None``, its offload stream
       (when present) waits on the device's current stream, and then the
       entry's prefetch state is released via ``uncast_prefetch_all``.
    2. The new head, if not ``None``, must belong to ``module``; its
       offload stream (when present) is passed to
       ``comfy.model_management.sync_stream``.
    3. The entry after the new head, if not ``None``, is replaced in place
       by the result of ``cast_prefetch_all`` for it.

    Args:
        queue: list built by ``make_prefetch_queue``; mutated in place.
        device: device used for stream lookups and the next prefetch.
        module: module expected at the new head of the queue.

    Raises:
        AssertionError: if the new head was prefetched for another module.
        IndexError: if fewer than two entries remain after the pop.
    """
    finished = queue.pop(0)
    if finished is not None:
        done_stream, done_state = finished
        if done_stream is not None:
            # NOTE(review): presumably keeps the offload stream from reusing
            # these weights before queued compute has finished -- confirm.
            done_stream.wait_stream(comfy.model_management.current_stream(device))
        uncast_prefetch_all(done_state)

    current = queue[0]
    if current is not None:
        ready_stream, ready_state = current
        assert ready_state[0] is module
        if ready_stream is not None:
            comfy.model_management.sync_stream(device, ready_stream)

    upcoming = queue[1]
    if upcoming is not None:
        queue[1] = cast_prefetch_all(upcoming, device)
|
||||||
|
|
||||||
|
|
||||||
|
def make_prefetch_queue(queue):
    """Build a padded prefetch queue and register it globally.

    The input list is wrapped with two ``None`` placeholders on each side,
    appended to ``comfy.model_management.PREFETCH_QUEUES`` and returned.
    The caller's list is not modified.

    Args:
        queue: list of entries to prefetch, in order.

    Returns:
        The new padded list.
    """
    padding = [None, None]
    padded = padding + queue + padding
    comfy.model_management.PREFETCH_QUEUES.append(padded)
    return padded
|
||||||
|
|
||||||
|
|
||||||
def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user