mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 10:02:59 +08:00
mm: wrap the raw stream in context manager (#10958)
The documentation of torch.foo.Stream being usable with with: suggests it starts at version 2.7. Use the old API for backwards compatibility.
This commit is contained in:
parent
6484ac89dc
commit
0ff0457892
@ -1055,7 +1055,9 @@ def get_offload_stream(device):
|
|||||||
elif is_device_cuda(device):
|
elif is_device_cuda(device):
|
||||||
ss = []
|
ss = []
|
||||||
for k in range(NUM_STREAMS):
|
for k in range(NUM_STREAMS):
|
||||||
ss.append(torch.cuda.Stream(device=device, priority=0))
|
s1 = torch.cuda.Stream(device=device, priority=0)
|
||||||
|
s1.as_context = torch.cuda.stream
|
||||||
|
ss.append(s1)
|
||||||
STREAMS[device] = ss
|
STREAMS[device] = ss
|
||||||
s = ss[stream_counter]
|
s = ss[stream_counter]
|
||||||
stream_counters[device] = stream_counter
|
stream_counters[device] = stream_counter
|
||||||
@ -1063,7 +1065,9 @@ def get_offload_stream(device):
|
|||||||
elif is_device_xpu(device):
|
elif is_device_xpu(device):
|
||||||
ss = []
|
ss = []
|
||||||
for k in range(NUM_STREAMS):
|
for k in range(NUM_STREAMS):
|
||||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
s1 = torch.xpu.Stream(device=device, priority=0)
|
||||||
|
s1.as_context = torch.xpu.stream
|
||||||
|
ss.append(s1)
|
||||||
STREAMS[device] = ss
|
STREAMS[device] = ss
|
||||||
s = ss[stream_counter]
|
s = ss[stream_counter]
|
||||||
stream_counters[device] = stream_counter
|
stream_counters[device] = stream_counter
|
||||||
@ -1081,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
|||||||
if dtype is None or weight.dtype == dtype:
|
if dtype is None or weight.dtype == dtype:
|
||||||
return weight
|
return weight
|
||||||
if stream is not None:
|
if stream is not None:
|
||||||
with stream:
|
wf_context = stream
|
||||||
|
if hasattr(wf_context, "as_context"):
|
||||||
|
wf_context = wf_context.as_context(stream)
|
||||||
|
with wf_context:
|
||||||
return weight.to(dtype=dtype, copy=copy)
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
return weight.to(dtype=dtype, copy=copy)
|
return weight.to(dtype=dtype, copy=copy)
|
||||||
|
|
||||||
|
|
||||||
if stream is not None:
|
if stream is not None:
|
||||||
with stream:
|
wf_context = stream
|
||||||
|
if hasattr(wf_context, "as_context"):
|
||||||
|
wf_context = wf_context.as_context(stream)
|
||||||
|
with wf_context:
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
r.copy_(weight, non_blocking=non_blocking)
|
r.copy_(weight, non_blocking=non_blocking)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -95,6 +95,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
|
|
||||||
if offload_stream is not None:
|
if offload_stream is not None:
|
||||||
wf_context = offload_stream
|
wf_context = offload_stream
|
||||||
|
if hasattr(wf_context, "as_context"):
|
||||||
|
wf_context = wf_context.as_context(offload_stream)
|
||||||
else:
|
else:
|
||||||
wf_context = contextlib.nullcontext()
|
wf_context = contextlib.nullcontext()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user