mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-04 21:37:40 +08:00
Merge remote-tracking branch 'upstream/master' into gaussians
This commit is contained in:
commit
7a2e908bae
@ -149,6 +149,7 @@ parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=Non
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
||||
parser.add_argument("--fast-disk", action="store_true", help="Prefer disk-backed dynamic loading and offload over unpinned RAM. Can be faster for users with fast NVME disks.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ import dataclasses
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
|
||||
import comfy_aimdo.host_buffer
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
|
||||
@ -17,21 +18,18 @@ class TensorFileSlice(NamedTuple):
|
||||
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
if not isinstance(destination, QuantizedTensor):
|
||||
return False
|
||||
if tensor._layout_cls != destination._layout_cls:
|
||||
return False
|
||||
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
|
||||
if not read_tensor_file_slice_into(tensor._qdata,
|
||||
destination._qdata if destination is not None else None, stream=stream,
|
||||
destination2=(destination2._qdata if destination2 is not None else None)):
|
||||
return False
|
||||
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
if destination is not None:
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
if destination2 is not None:
|
||||
dst_orig_dtype = destination2._params.orig_dtype
|
||||
destination2._params.copy_from(destination._params, non_blocking=True)
|
||||
destination2._params.copy_from(destination._params if destination is not None else tensor._params, non_blocking=True)
|
||||
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
|
||||
return True
|
||||
|
||||
@ -39,10 +37,15 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
if info is None:
|
||||
return False
|
||||
|
||||
if destination is not None and destination.device.type != "cpu" and destination2 is None:
|
||||
destination2 = destination
|
||||
destination = None
|
||||
|
||||
file_obj = info.file_ref
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
if (file_obj is None
|
||||
or (destination is None and destination2 is None)
|
||||
or (destination is not None and (destination.device.type != "cpu" or destination.numel() * destination.element_size() < info.size))
|
||||
or (destination2 is not None and (destination2.device.type == "cpu" or destination2.numel() * destination2.element_size() < info.size))
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
or not tensor.is_contiguous()):
|
||||
@ -51,6 +54,14 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
if info.size == 0:
|
||||
return True
|
||||
|
||||
if destination is None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
comfy_aimdo.host_buffer.read_file_to_device(file_obj, info.offset, info.size,
|
||||
stream_ptr, destination2.data_ptr(),
|
||||
destination2.device.index,
|
||||
mark_cold=False)
|
||||
return True
|
||||
|
||||
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
|
||||
if hostbuf is not None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
@ -63,6 +74,9 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
device=None if destination2 is None else destination2.device.index)
|
||||
return True
|
||||
|
||||
if not hasattr(file_obj, "seek") or not hasattr(file_obj, "readinto"):
|
||||
return False
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
|
||||
@ -641,14 +641,17 @@ def free_pins(size, evict_active=False):
|
||||
return freed_total
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
|
||||
if args.fast_disk:
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
else:
|
||||
shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
|
||||
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
||||
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
||||
|
||||
def ensure_pin_registerable(size, evict_active=False):
|
||||
def ensure_pin_registerable(size, evict_active=True):
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
@ -658,10 +661,17 @@ def ensure_pin_registerable(size, evict_active=False):
|
||||
shortfall += REGISTERABLE_PIN_HYSTERESIS
|
||||
for loaded_model in reversed(current_loaded_models):
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
|
||||
if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]:
|
||||
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
if evict_active:
|
||||
for loaded_model in current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]:
|
||||
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||
|
||||
class LoadedModel:
|
||||
@ -803,9 +813,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
|
||||
if not DISABLE_SMART_MEMORY or device is None:
|
||||
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
||||
if for_dynamic:
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||
@ -817,6 +827,10 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
|
||||
if not for_dynamic and pins_required > 0:
|
||||
ensure_pin_budget(pins_required)
|
||||
ensure_pin_registerable(pins_required)
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
elif device is not None:
|
||||
@ -879,15 +893,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
model_to_unload.model_finalizer.detach()
|
||||
|
||||
total_memory_required = {}
|
||||
total_pins_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
device = loaded_model.device
|
||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||
if not loaded_model.model.is_dynamic():
|
||||
total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory()
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||
device,
|
||||
for_dynamic=free_for_dynamic)
|
||||
for_dynamic=free_for_dynamic,
|
||||
pins_required=total_pins_required.get(device, 0))
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
@ -1283,7 +1301,6 @@ STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_PIN_BUFFERS = {}
|
||||
|
||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||
|
||||
@ -1326,42 +1343,13 @@ def get_aimdo_cast_buffer(offload_stream, device):
|
||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
return cast_buffer
|
||||
|
||||
def get_pin_buffer(offload_stream):
|
||||
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
|
||||
if pin_buffer is None:
|
||||
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False)
|
||||
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
|
||||
elif offload_stream is not None:
|
||||
event = getattr(pin_buffer, "_comfy_event", None)
|
||||
if event is not None:
|
||||
event.synchronize()
|
||||
delattr(pin_buffer, "_comfy_event")
|
||||
return pin_buffer
|
||||
|
||||
def resize_pin_buffer(pin_buffer, size):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
old_size = pin_buffer.size
|
||||
if size <= old_size:
|
||||
return True
|
||||
growth = size - old_size
|
||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||
ensure_pin_budget(growth, evict_active=True)
|
||||
ensure_pin_registerable(growth, evict_active=True)
|
||||
try:
|
||||
pin_buffer.extend(size=size, reallocate=True)
|
||||
except RuntimeError:
|
||||
return False
|
||||
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
|
||||
return True
|
||||
|
||||
def reset_cast_buffers():
|
||||
global TOTAL_PINNED_MEMORY
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||
if offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
@ -1370,20 +1358,24 @@ def reset_cast_buffers():
|
||||
mmap_obj.bounce()
|
||||
DIRTY_MMAPS.clear()
|
||||
|
||||
for pin_buffer in STREAM_PIN_BUFFERS.values():
|
||||
TOTAL_PINNED_MEMORY -= pin_buffer.size
|
||||
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
|
||||
|
||||
for loaded_model in current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic():
|
||||
model.model.dynamic_pins[model.load_device]["active"] = False
|
||||
pin_state = model.model.dynamic_pins[model.load_device]
|
||||
|
||||
if pin_state["active"]:
|
||||
*_, buckets = pin_state["weights"]
|
||||
for size, bucket in list(buckets.items()):
|
||||
bucket[:] = [ entry for entry in bucket if entry[-1] is not None ]
|
||||
if not bucket:
|
||||
del buckets[size]
|
||||
|
||||
pin_state["active"] = False
|
||||
model.partially_unload_ram(1e30, subsets=[ "patches" ])
|
||||
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
|
||||
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {})
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
STREAM_PIN_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
@ -1436,7 +1428,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
|
||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
|
||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors)
|
||||
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
|
||||
with wf_context:
|
||||
for tensor in tensors:
|
||||
@ -1448,9 +1440,10 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
||||
continue
|
||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||
mark_mmap_dirty(storage)
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
if dest_view is not None:
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
if dest2_view is not None:
|
||||
dest2_view.copy_(dest_view, non_blocking=non_blocking)
|
||||
dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking)
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
|
||||
@ -1721,8 +1721,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
"""
|
||||
if device not in self.model.dynamic_pins:
|
||||
self.model.dynamic_pins[device] = {
|
||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
|
||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
|
||||
"hostbufs_initialized": False,
|
||||
"failed": False,
|
||||
"active": False,
|
||||
@ -1799,8 +1799,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
if not pin_state["hostbufs_initialized"]:
|
||||
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
|
||||
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
|
||||
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
|
||||
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
|
||||
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
|
||||
pin_state["hostbufs_initialized"] = True
|
||||
pin_state["failed"] = False
|
||||
pin_state["active"] = True
|
||||
@ -1942,18 +1942,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
return freed
|
||||
|
||||
def loaded_ram_size(self):
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
|
||||
self.model.dynamic_pins[self.load_device]["patches"][0].size)
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][0].size)
|
||||
|
||||
def pinned_memory_size(self):
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
|
||||
self.model.dynamic_pins[self.load_device]["patches"][3][0])
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][3][0])
|
||||
|
||||
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
|
||||
freed = 0
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
for subset in subsets:
|
||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
||||
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
|
||||
split = stack_split[0]
|
||||
while split >= 0:
|
||||
module, offset = stack[split]
|
||||
@ -1978,10 +1976,12 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
freed = 0
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
for subset in subsets:
|
||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
||||
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
|
||||
while len(stack) > 0:
|
||||
module, offset = stack.pop()
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
module._pin_balancer_entry[-1] = None
|
||||
del module._pin_balancer_entry
|
||||
del module._pin
|
||||
hostbuf.truncate(offset, do_unregister=module._pin_registered)
|
||||
stack_split[0] = min(stack_split[0], len(stack) - 1)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import comfy_aimdo.model_vbar
|
||||
import comfy.memory_management
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
|
||||
@ -50,7 +51,17 @@ def prefetch_queue_pop(queue, device, module):
|
||||
if hasattr(s, "_v"):
|
||||
comfy_modules.append(s)
|
||||
|
||||
registerable_size = 0
|
||||
for s in comfy_modules:
|
||||
registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias])
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
registerable_size += lowvram_fn.memory_required()
|
||||
|
||||
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||
if not comfy.model_management.args.fast_disk:
|
||||
comfy.model_management.ensure_pin_registerable(registerable_size)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
queue[0] = (offload_stream, (prefetch, comfy_modules))
|
||||
|
||||
|
||||
66
comfy/ops.py
66
comfy/ops.py
@ -76,8 +76,6 @@ except:
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
@ -94,9 +92,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
offload_stream = None
|
||||
cast_buffer = None
|
||||
cast_buffer_offset = 0
|
||||
stream_pin_hostbuf = None
|
||||
stream_pin_offset = 0
|
||||
stream_pin_queue = []
|
||||
|
||||
def ensure_offload_stream(module, required_size, check_largest):
|
||||
nonlocal offload_stream
|
||||
@ -130,22 +125,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
cast_buffer_offset += buffer_size
|
||||
return buffer
|
||||
|
||||
def get_stream_pin_buffer_offset(buffer_size):
|
||||
nonlocal stream_pin_hostbuf
|
||||
nonlocal stream_pin_offset
|
||||
|
||||
if buffer_size == 0 or offload_stream is None:
|
||||
return None
|
||||
|
||||
if stream_pin_hostbuf is None:
|
||||
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
|
||||
if stream_pin_hostbuf is None:
|
||||
return None
|
||||
|
||||
offset = stream_pin_offset
|
||||
stream_pin_offset += buffer_size
|
||||
return offset
|
||||
|
||||
for s in comfy_modules:
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
@ -184,12 +163,18 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if xfer_dest is None:
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
|
||||
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream, xfer_dest2=None):
|
||||
if xfer_source is not None:
|
||||
if getattr(xfer_source, "is_lowvram_patch", False):
|
||||
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
||||
else:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
|
||||
if xfer_dest is not None:
|
||||
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
||||
xfer_source = [ xfer_dest ]
|
||||
xfer_dest = xfer_dest2
|
||||
xfer_dest2 = None
|
||||
elif xfer_dest2 is not None:
|
||||
xfer_source.prepare(xfer_dest2, stream, copy=True, commit=False)
|
||||
return
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream, r2=xfer_dest2)
|
||||
|
||||
def handle_pin(m, pin, source, dest, subset="weights", size=None):
|
||||
if pin is not None:
|
||||
@ -198,19 +183,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if signature is None:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
if pin is not None:
|
||||
if isinstance(source, list):
|
||||
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
|
||||
else:
|
||||
cast_maybe_lowvram_patch(source, pin, None)
|
||||
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
|
||||
return
|
||||
if pin is None:
|
||||
pin_offset = get_stream_pin_buffer_offset(size)
|
||||
if pin_offset is not None:
|
||||
stream_pin_queue.append((source, pin_offset, size, dest))
|
||||
return
|
||||
cast_maybe_lowvram_patch(source, dest, offload_stream)
|
||||
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
||||
|
||||
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
|
||||
|
||||
@ -232,23 +205,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
prefetch["needs_cast"] = needs_cast
|
||||
s._prefetch = prefetch
|
||||
|
||||
if stream_pin_offset > 0:
|
||||
if stream_pin_hostbuf.size < stream_pin_offset:
|
||||
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
|
||||
for xfer_source, _, _, xfer_dest in stream_pin_queue:
|
||||
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
|
||||
return offload_stream
|
||||
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
|
||||
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
|
||||
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
|
||||
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
|
||||
if isinstance(xfer_source, list):
|
||||
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
|
||||
else:
|
||||
cast_maybe_lowvram_patch(xfer_source, pin, None)
|
||||
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
|
||||
|
||||
return offload_stream
|
||||
|
||||
|
||||
|
||||
@ -1,17 +1,55 @@
|
||||
import bisect
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
def _add_to_bucket(module, buckets, size, priority):
|
||||
bucket = buckets.setdefault(size, [])
|
||||
entry = [-priority, 0, module]
|
||||
entry[1] = id(entry)
|
||||
bisect.insort(bucket, entry)
|
||||
module._pin_balancer_entry = entry
|
||||
|
||||
def _steal_pin(module, stack, buckets, size, priority):
|
||||
bucket = buckets.get(size)
|
||||
if bucket is None:
|
||||
return False
|
||||
|
||||
while bucket and bucket[-1][-1] is None:
|
||||
bucket.pop()
|
||||
if not bucket:
|
||||
del buckets[size]
|
||||
return False
|
||||
|
||||
if priority <= -bucket[-1][0]:
|
||||
return False
|
||||
|
||||
*_, victim = bucket.pop()
|
||||
module._pin = victim._pin
|
||||
module._pin_registered = victim._pin_registered
|
||||
module._pin_stack_index = victim._pin_stack_index
|
||||
stack[module._pin_stack_index] = (module, stack[module._pin_stack_index][1])
|
||||
|
||||
victim._pin_registered = False
|
||||
del victim._pin
|
||||
del victim._pin_stack_index
|
||||
del victim._pin_balancer_entry
|
||||
|
||||
_add_to_bucket(module, buckets, size, priority)
|
||||
return True
|
||||
|
||||
def get_pin(module, subset="weights"):
|
||||
pin = getattr(module, "_pin", None)
|
||||
if pin is None or module._pin_registered or args.disable_pinned_memory:
|
||||
return pin
|
||||
|
||||
_, _, stack_split, pinned_size = module._pin_state[subset]
|
||||
_, _, stack_split, pinned_size, *_ = module._pin_state[subset]
|
||||
size = pin.nbytes
|
||||
comfy.model_management.ensure_pin_registerable(size)
|
||||
|
||||
@ -31,26 +69,30 @@ def pin_memory(module, subset="weights", size=None):
|
||||
return
|
||||
|
||||
pin = get_pin(module, subset)
|
||||
if pin is not None or pin_state["failed"]:
|
||||
if pin is not None:
|
||||
return
|
||||
|
||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
||||
hostbuf, stack, stack_split, pinned_size, counter, buckets = pin_state[subset]
|
||||
if size is None:
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
offset = hostbuf.size
|
||||
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
|
||||
registerable_size = size
|
||||
priority = getattr(module, "_pin_balancer_priority", None)
|
||||
|
||||
if priority is None:
|
||||
priority = comfy.utils.bit_reverse_range(counter[0], 16)
|
||||
counter[0] += 1
|
||||
module._pin_balancer_priority = priority
|
||||
|
||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||
if (not comfy.model_management.ensure_pin_budget(size) or
|
||||
not comfy.model_management.ensure_pin_registerable(registerable_size)):
|
||||
pin_state["failed"] = True
|
||||
return False
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
|
||||
try:
|
||||
hostbuf.extend(size=size)
|
||||
except RuntimeError:
|
||||
pin_state["failed"] = True
|
||||
return False
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
@ -60,4 +102,5 @@ def pin_memory(module, subset="weights", size=None):
|
||||
stack_split[0] = max(stack_split[0], module._pin_stack_index)
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
pinned_size[0] += size
|
||||
_add_to_bucket(module, buckets, size, priority)
|
||||
return True
|
||||
|
||||
@ -85,9 +85,9 @@ _TYPES = {
|
||||
def load_safetensors(ckpt):
|
||||
import comfy_aimdo.model_mmap
|
||||
|
||||
f = open(ckpt, "rb", buffering=0)
|
||||
file_lock = threading.Lock()
|
||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||
f = model_mmap.get_file_handle()
|
||||
file_size = os.path.getsize(ckpt)
|
||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||
|
||||
@ -1452,3 +1452,10 @@ def deepcopy_list_dict(obj, memo=None):
|
||||
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
def bit_reverse_range(index, bits):
|
||||
result = 0
|
||||
for _ in range(bits):
|
||||
result = (result << 1) | (index & 1)
|
||||
index >>= 1
|
||||
return result
|
||||
|
||||
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.10
|
||||
comfy-aimdo==0.4.5
|
||||
comfy-aimdo==0.4.7
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
Loading…
Reference in New Issue
Block a user