Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-13 23:12:35 +08:00)
mm: factor out the current stream getter
Make this a reusable function.
commit 313638d13a (parent e525673f72)
@@ -1013,6 +1013,16 @@ if args.async_offload:
     NUM_STREAMS = 2
     logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
 
+def current_stream(device):
+    if device is None:
+        return None
+    if is_device_cuda(device):
+        return torch.cuda.current_stream()
+    elif is_device_xpu(device):
+        return torch.xpu.current_stream()
+    else:
+        return None
+
 stream_counters = {}
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
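The factored-out helper simply dispatches on the device backend and returns None when no stream applies. A rough usage sketch (assuming, as the surrounding functions suggest, that the helper lives in comfy.model_management, and that a CUDA device is available):

# Sketch only: the module path and CUDA availability are assumptions.
import torch
import comfy.model_management as mm

gpu = torch.device("cuda:0")
cpu = torch.device("cpu")

print(mm.current_stream(gpu))   # the active CUDA stream, via torch.cuda.current_stream()
print(mm.current_stream(cpu))   # None: CPU is neither CUDA nor XPU
print(mm.current_stream(None))  # None: guarded explicitly at the top of the helper

The None return is what lets the updated sync_stream below bail out early on devices that have no stream.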
@@ -1023,10 +1033,7 @@ def get_offload_stream(device):
         ss = STREAMS[device]
         s = ss[stream_counter]
         stream_counter = (stream_counter + 1) % len(ss)
-        if is_device_cuda(device):
-            ss[stream_counter].wait_stream(torch.cuda.current_stream())
-        elif is_device_xpu(device):
-            ss[stream_counter].wait_stream(torch.xpu.current_stream())
+        ss[stream_counter].wait_stream(current_stream(device))
         stream_counters[device] = stream_counter
         return s
     elif is_device_cuda(device):
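The single wait_stream call above keeps the old per-backend behaviour: the offload stream is ordered after work already queued on the compute stream, so a staged copy never reads tensors that are still being produced. A standalone PyTorch sketch of that ordering (plain torch, CUDA assumed; none of these names are ComfyUI's):

# Standalone illustration of stream ordering with wait_stream.
import torch

dev = torch.device("cuda")
compute = torch.cuda.current_stream(dev)
side = torch.cuda.Stream(device=dev)

a = torch.randn(4096, 4096, device=dev)   # enqueued on the compute stream
side.wait_stream(compute)                 # side stream waits for work queued so far
with torch.cuda.stream(side):
    b = a.to("cpu", non_blocking=True)    # safe: the producer of `a` has completed
compute.wait_stream(side)                 # compute stream waits before depending on `b`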
@@ -1050,12 +1057,9 @@ def get_offload_stream(device):
     return None
 
 def sync_stream(device, stream):
-    if stream is None:
+    if stream is None or current_stream(device) is None:
         return
-    if is_device_cuda(device):
-        torch.cuda.current_stream().wait_stream(stream)
-    elif is_device_xpu(device):
-        torch.xpu.current_stream().wait_stream(stream)
+    current_stream(device).wait_stream(stream)
 
 def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if device is None or weight.device == device:
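Call sites keep the same three-step pattern: take an offload stream, stage the copy with cast_to on it, then let sync_stream make the compute stream wait before the weight is used. A hedged sketch of that flow (offloaded_cast is an illustrative name, not an actual ComfyUI function; the import path is an assumption):

# Illustrative flow only, not ComfyUI's real call site.
import comfy.model_management as mm

def offloaded_cast(weight, device, dtype=None):
    stream = mm.get_offload_stream(device)            # round-robin side stream, or None
    w = mm.cast_to(weight, dtype=dtype, device=device,
                   non_blocking=True, stream=stream)  # copy runs on the offload stream
    mm.sync_stream(device, stream)                    # current_stream(device) waits on it
    return w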