mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
mm: wrap the raw stream in context manager (#10958)
The documentation of torch.foo.Stream being usable with with: suggests it starts at version 2.7. Use the old API for backwards compatibility.
This commit is contained in:
parent
6484ac89dc
commit
0ff0457892
@ -1055,7 +1055,9 @@ def get_offload_stream(device):
|
||||
elif is_device_cuda(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.cuda.Stream(device=device, priority=0))
|
||||
s1 = torch.cuda.Stream(device=device, priority=0)
|
||||
s1.as_context = torch.cuda.stream
|
||||
ss.append(s1)
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counters[device] = stream_counter
|
||||
@ -1063,7 +1065,9 @@ def get_offload_stream(device):
|
||||
elif is_device_xpu(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
||||
s1 = torch.xpu.Stream(device=device, priority=0)
|
||||
s1.as_context = torch.xpu.stream
|
||||
ss.append(s1)
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counters[device] = stream_counter
|
||||
@ -1081,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
if stream is not None:
|
||||
with stream:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
|
||||
if stream is not None:
|
||||
with stream:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
else:
|
||||
|
||||
@ -95,6 +95,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(offload_stream)
|
||||
else:
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user