mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-10 16:27:33 +08:00
mm: delegate pin-on-the-way to aimdo
Aimdo is able to chunk and slice this on the way for better CPU->GPU overlap. The main advantage is the ability to shorten the bus contention window between the previous weight transfer and the vbar fault for the next weights.
This commit is contained in:
parent
ed15d62a6c
commit
43865639a7
@ -15,7 +15,7 @@ class TensorFileSlice(NamedTuple):
|
||||
size: int
|
||||
|
||||
|
||||
def read_tensor_file_slice_into(tensor, destination):
|
||||
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
if not isinstance(destination, QuantizedTensor):
|
||||
@ -23,12 +23,17 @@ def read_tensor_file_slice_into(tensor, destination):
|
||||
if tensor._layout_cls != destination._layout_cls:
|
||||
return False
|
||||
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
|
||||
destination2=(destination2._qdata if destination2 is not None else None)):
|
||||
return False
|
||||
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
if destination2 is not None:
|
||||
dst_orig_dtype = destination2._params.orig_dtype
|
||||
destination2._params.copy_from(destination._params, non_blocking=True)
|
||||
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
|
||||
return True
|
||||
|
||||
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
|
||||
@ -50,8 +55,12 @@ def read_tensor_file_slice_into(tensor, destination):
|
||||
|
||||
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
|
||||
if hostbuf is not None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
device_ptr = destination2.data_ptr() if destination2 is not None else 0
|
||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address())
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address(),
|
||||
stream=stream_ptr,
|
||||
device_ptr=device_ptr)
|
||||
return True
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
|
||||
@ -1318,7 +1318,7 @@ def sync_stream(device, stream):
|
||||
current_stream(device).wait_stream(stream)
|
||||
|
||||
|
||||
def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
||||
wf_context = nullcontext()
|
||||
if stream is not None:
|
||||
wf_context = stream
|
||||
@ -1326,16 +1326,20 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
|
||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
|
||||
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
|
||||
with wf_context:
|
||||
for tensor in tensors:
|
||||
dest_view = dest_views.pop(0)
|
||||
dest2_view = dest2_views.pop(0) if dest2_views is not None else None
|
||||
if tensor is None:
|
||||
continue
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view):
|
||||
continue
|
||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||
mark_mmap_dirty(storage)
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
if dest2_view is not None:
|
||||
dest2_view.copy_(dest_view, non_blocking=non_blocking)
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
|
||||
14
comfy/ops.py
14
comfy/ops.py
@ -198,8 +198,11 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch(source, pin, None)
|
||||
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
|
||||
if isinstance(source, list):
|
||||
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
|
||||
else:
|
||||
cast_maybe_lowvram_patch(source, pin, None)
|
||||
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
|
||||
return
|
||||
if pin is None:
|
||||
pin_offset = get_stream_pin_buffer_offset(size)
|
||||
@ -238,8 +241,11 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
|
||||
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
|
||||
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
|
||||
cast_maybe_lowvram_patch(xfer_source, pin, None)
|
||||
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
if isinstance(xfer_source, list):
|
||||
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
|
||||
else:
|
||||
cast_maybe_lowvram_patch(xfer_source, pin, None)
|
||||
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
|
||||
|
||||
return offload_stream
|
||||
|
||||
Loading…
Reference in New Issue
Block a user