mm: wrap the raw stream in context manager (#10958)

The documentation of torch.foo.Stream being usable with with: suggests
it starts at version 2.7. Use the old API for backwards compatibility.
This commit is contained in:
rattus 2025-11-29 07:38:12 +10:00 committed by GitHub
parent 6484ac89dc
commit 0ff0457892
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 4 deletions

View File

@ -1055,7 +1055,9 @@ def get_offload_stream(device):
elif is_device_cuda(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0))
s1 = torch.cuda.Stream(device=device, priority=0)
s1.as_context = torch.cuda.stream
ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
stream_counters[device] = stream_counter
@ -1063,7 +1065,9 @@ def get_offload_stream(device):
elif is_device_xpu(device):
ss = []
for k in range(NUM_STREAMS):
ss.append(torch.xpu.Stream(device=device, priority=0))
s1 = torch.xpu.Stream(device=device, priority=0)
s1.as_context = torch.xpu.stream
ss.append(s1)
STREAMS[device] = ss
s = ss[stream_counter]
stream_counters[device] = stream_counter
@ -1081,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None or weight.dtype == dtype:
return weight
if stream is not None:
with stream:
wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy)
if stream is not None:
with stream:
wf_context = stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
with wf_context:
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
else:

View File

@ -95,6 +95,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()