mm: sync the last stream in the queue, not the next

Currently this peeks ahead to sync the next stream in the queue of
streams with the compute stream. This doesnt allow a lot of
parallelization, as then end result is you can only get one weight load
ahead regardless of how many streams you have.

Rotate the loop logic here to synchronize the end of the queue before
returning the next stream. This allows weights to be loaded ahead of the
compute streams position.
This commit is contained in:
Rattus 2025-10-27 18:54:32 +10:00
parent 432a5a9d82
commit 6ceba8c9f8

View File

@ -1031,18 +1031,17 @@ def get_offload_stream(device):
if device in STREAMS:
ss = STREAMS[device]
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
#Sync the oldest stream in the queue with the current
ss[stream_counter].wait_stream(current_stream(device))
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return ss[stream_counter]
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
elif is_device_xpu(device):
@ -1051,7 +1050,6 @@ def get_offload_stream(device):
ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss
s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter
return s
return None