diff --git a/comfy/model_management.py b/comfy/model_management.py index cd8326f5f..79c0dfdb4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1031,18 +1031,17 @@ def get_offload_stream(device): if device in STREAMS: ss = STREAMS[device] - s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) + #Sync the oldest stream in the queue with the current ss[stream_counter].wait_stream(current_stream(device)) + stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter - return s + return ss[stream_counter] elif is_device_cuda(device): ss = [] for k in range(NUM_STREAMS): ss.append(torch.cuda.Stream(device=device, priority=0)) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s elif is_device_xpu(device): @@ -1051,7 +1050,6 @@ def get_offload_stream(device): ss.append(torch.xpu.Stream(device=device, priority=0)) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s return None