From 313638d13ab6c0f28e8afb7d2cd7c4c67371b231 Mon Sep 17 00:00:00 2001
From: Rattus
Date: Mon, 27 Oct 2025 18:52:07 +1000
Subject: [PATCH] mm: factor out the current stream getter

Make this a reusable function.
---
 comfy/model_management.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 3e5b977d4..cd8326f5f 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1013,6 +1013,16 @@ if args.async_offload:
     NUM_STREAMS = 2
     logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
 
+def current_stream(device):
+    if device is None:
+        return None
+    if is_device_cuda(device):
+        return torch.cuda.current_stream()
+    elif is_device_xpu(device):
+        return torch.xpu.current_stream()
+    else:
+        return None
+
 stream_counters = {}
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
@@ -1023,10 +1033,7 @@ def get_offload_stream(device):
         ss = STREAMS[device]
         s = ss[stream_counter]
         stream_counter = (stream_counter + 1) % len(ss)
-        if is_device_cuda(device):
-            ss[stream_counter].wait_stream(torch.cuda.current_stream())
-        elif is_device_xpu(device):
-            ss[stream_counter].wait_stream(torch.xpu.current_stream())
+        ss[stream_counter].wait_stream(current_stream(device))
         stream_counters[device] = stream_counter
         return s
     elif is_device_cuda(device):
@@ -1050,12 +1057,9 @@ def get_offload_stream(device):
     return None
 
 def sync_stream(device, stream):
-    if stream is None:
+    if stream is None or current_stream(device) is None:
         return
-    if is_device_cuda(device):
-        torch.cuda.current_stream().wait_stream(stream)
-    elif is_device_xpu(device):
-        torch.xpu.current_stream().wait_stream(stream)
+    current_stream(device).wait_stream(stream)
 
 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device: