From 43865639a74841dfaada96ca49485f570e231b46 Mon Sep 17 00:00:00 2001 From: Rattus Date: Sat, 16 May 2026 02:48:51 +1000 Subject: [PATCH] mm: delegate pin-on-the-way to aimdo Aimdo is able to chunk and slice this on the way for better CPU->GPU overlap. The main advantage is the ability to shorten the bus contention window between previous weight transfer and the next weights vbar fault. --- comfy/memory_management.py | 15 ++++++++++++--- comfy/model_management.py | 8 ++++++-- comfy/ops.py | 14 ++++++++++---- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 7645064f5..21e3cf59b 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -15,7 +15,7 @@ class TensorFileSlice(NamedTuple): size: int -def read_tensor_file_slice_into(tensor, destination): +def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None): if isinstance(tensor, QuantizedTensor): if not isinstance(destination, QuantizedTensor): @@ -23,12 +23,17 @@ def read_tensor_file_slice_into(tensor, destination): if tensor._layout_cls != destination._layout_cls: return False - if not read_tensor_file_slice_into(tensor._qdata, destination._qdata): + if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream, + destination2=(destination2._qdata if destination2 is not None else None)): return False dst_orig_dtype = destination._params.orig_dtype destination._params.copy_from(tensor._params, non_blocking=False) destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) + if destination2 is not None: + dst_orig_dtype = destination2._params.orig_dtype + destination2._params.copy_from(destination._params, non_blocking=True) + destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype) return True info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None) @@ -50,8 +55,12 @@ def read_tensor_file_slice_into(tensor, destination): hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None) if hostbuf is not None: + stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 + device_ptr = destination2.data_ptr() if destination2 is not None else 0 hostbuf.read_file_slice(file_obj, info.offset, info.size, - offset=destination.data_ptr() - hostbuf.get_raw_address()) + offset=destination.data_ptr() - hostbuf.get_raw_address(), + stream=stream_ptr, + device_ptr=device_ptr) return True buf_type = ctypes.c_ubyte * info.size diff --git a/comfy/model_management.py b/comfy/model_management.py index 19a916362..72ef77ee9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1318,7 +1318,7 @@ def sync_stream(device, stream): current_stream(device).wait_stream(stream) -def cast_to_gathered(tensors, r, non_blocking=False, stream=None): +def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None): wf_context = nullcontext() if stream is not None: wf_context = stream @@ -1326,16 +1326,20 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): wf_context = wf_context.as_context(stream) dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None with wf_context: for tensor in tensors: dest_view = dest_views.pop(0) + dest2_view = dest2_views.pop(0) if dest2_views is not None else None if tensor is None: continue - if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view): + if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view): continue storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() mark_mmap_dirty(storage) dest_view.copy_(tensor, non_blocking=non_blocking) + if dest2_view is not None: + dest2_view.copy_(dest_view, non_blocking=non_blocking) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): diff --git a/comfy/ops.py b/comfy/ops.py index 4b436f4a7..2f364a3a5 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -198,8 +198,11 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin comfy.pinned_memory.pin_memory(m, subset=subset, size=size) pin = comfy.pinned_memory.get_pin(m, subset=subset) if pin is not None: - cast_maybe_lowvram_patch(source, pin, None) - cast_maybe_lowvram_patch([ pin ], dest, offload_stream) + if isinstance(source, list): + comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest) + else: + cast_maybe_lowvram_patch(source, pin, None) + cast_maybe_lowvram_patch([ pin ], dest, offload_stream) return if pin is None: pin_offset = get_stream_pin_buffer_offset(size) @@ -238,8 +241,11 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue: pin = stream_pin_tensor[pin_offset:pin_offset + pin_size] - cast_maybe_lowvram_patch(xfer_source, pin, None) - comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream) + if isinstance(xfer_source, list): + comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest) + else: + cast_maybe_lowvram_patch(xfer_source, pin, None) + comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream) stream_pin_hostbuf._comfy_event = offload_stream.record_event() return offload_stream