mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-14 20:09:24 +08:00
Merge remote-tracking branch 'upstream/master' into gaussians
This commit is contained in:
commit
7a2e908bae
@ -149,6 +149,7 @@ parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=Non
|
|||||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||||
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
||||||
|
parser.add_argument("--fast-disk", action="store_true", help="Prefer disk-backed dynamic loading and offload over unpinned RAM. Can be faster for users with fast NVME disks.")
|
||||||
|
|
||||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import dataclasses
|
|||||||
import torch
|
import torch
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
import comfy_aimdo.host_buffer
|
||||||
from comfy.quant_ops import QuantizedTensor
|
from comfy.quant_ops import QuantizedTensor
|
||||||
|
|
||||||
|
|
||||||
@ -17,21 +18,18 @@ class TensorFileSlice(NamedTuple):
|
|||||||
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
|
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
|
||||||
|
|
||||||
if isinstance(tensor, QuantizedTensor):
|
if isinstance(tensor, QuantizedTensor):
|
||||||
if not isinstance(destination, QuantizedTensor):
|
if not read_tensor_file_slice_into(tensor._qdata,
|
||||||
return False
|
destination._qdata if destination is not None else None, stream=stream,
|
||||||
if tensor._layout_cls != destination._layout_cls:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
|
|
||||||
destination2=(destination2._qdata if destination2 is not None else None)):
|
destination2=(destination2._qdata if destination2 is not None else None)):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
dst_orig_dtype = destination._params.orig_dtype
|
if destination is not None:
|
||||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
dst_orig_dtype = destination._params.orig_dtype
|
||||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||||
|
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||||
if destination2 is not None:
|
if destination2 is not None:
|
||||||
dst_orig_dtype = destination2._params.orig_dtype
|
dst_orig_dtype = destination2._params.orig_dtype
|
||||||
destination2._params.copy_from(destination._params, non_blocking=True)
|
destination2._params.copy_from(destination._params if destination is not None else tensor._params, non_blocking=True)
|
||||||
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
|
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -39,10 +37,15 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
|||||||
if info is None:
|
if info is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if destination is not None and destination.device.type != "cpu" and destination2 is None:
|
||||||
|
destination2 = destination
|
||||||
|
destination = None
|
||||||
|
|
||||||
file_obj = info.file_ref
|
file_obj = info.file_ref
|
||||||
if (destination.device.type != "cpu"
|
if (file_obj is None
|
||||||
or file_obj is None
|
or (destination is None and destination2 is None)
|
||||||
or destination.numel() * destination.element_size() < info.size
|
or (destination is not None and (destination.device.type != "cpu" or destination.numel() * destination.element_size() < info.size))
|
||||||
|
or (destination2 is not None and (destination2.device.type == "cpu" or destination2.numel() * destination2.element_size() < info.size))
|
||||||
or tensor.numel() * tensor.element_size() != info.size
|
or tensor.numel() * tensor.element_size() != info.size
|
||||||
or tensor.storage_offset() != 0
|
or tensor.storage_offset() != 0
|
||||||
or not tensor.is_contiguous()):
|
or not tensor.is_contiguous()):
|
||||||
@ -51,6 +54,14 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
|||||||
if info.size == 0:
|
if info.size == 0:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if destination is None:
|
||||||
|
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||||
|
comfy_aimdo.host_buffer.read_file_to_device(file_obj, info.offset, info.size,
|
||||||
|
stream_ptr, destination2.data_ptr(),
|
||||||
|
destination2.device.index,
|
||||||
|
mark_cold=False)
|
||||||
|
return True
|
||||||
|
|
||||||
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
|
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
|
||||||
if hostbuf is not None:
|
if hostbuf is not None:
|
||||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||||
@ -63,6 +74,9 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
|||||||
device=None if destination2 is None else destination2.device.index)
|
device=None if destination2 is None else destination2.device.index)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if not hasattr(file_obj, "seek") or not hasattr(file_obj, "readinto"):
|
||||||
|
return False
|
||||||
|
|
||||||
buf_type = ctypes.c_ubyte * info.size
|
buf_type = ctypes.c_ubyte * info.size
|
||||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||||
|
|
||||||
|
|||||||
@ -641,14 +641,17 @@ def free_pins(size, evict_active=False):
|
|||||||
return freed_total
|
return freed_total
|
||||||
|
|
||||||
def ensure_pin_budget(size, evict_active=False):
|
def ensure_pin_budget(size, evict_active=False):
|
||||||
shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
|
if args.fast_disk:
|
||||||
|
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||||
|
else:
|
||||||
|
shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available
|
||||||
if shortfall <= 0:
|
if shortfall <= 0:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
||||||
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
||||||
|
|
||||||
def ensure_pin_registerable(size, evict_active=False):
|
def ensure_pin_registerable(size, evict_active=True):
|
||||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||||
if MAX_PINNED_MEMORY <= 0:
|
if MAX_PINNED_MEMORY <= 0:
|
||||||
return False
|
return False
|
||||||
@ -658,10 +661,17 @@ def ensure_pin_registerable(size, evict_active=False):
|
|||||||
shortfall += REGISTERABLE_PIN_HYSTERESIS
|
shortfall += REGISTERABLE_PIN_HYSTERESIS
|
||||||
for loaded_model in reversed(current_loaded_models):
|
for loaded_model in reversed(current_loaded_models):
|
||||||
model = loaded_model.model
|
model = loaded_model.model
|
||||||
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
|
if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]:
|
||||||
shortfall -= model.unregister_inactive_pins(shortfall)
|
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||||
if shortfall <= 0:
|
if shortfall <= 0:
|
||||||
return True
|
return True
|
||||||
|
if evict_active:
|
||||||
|
for loaded_model in current_loaded_models:
|
||||||
|
model = loaded_model.model
|
||||||
|
if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]:
|
||||||
|
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||||
|
if shortfall <= 0:
|
||||||
|
return True
|
||||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
@ -803,9 +813,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
for x in can_unload_sorted:
|
for x in can_unload_sorted:
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
memory_to_free = 1e32
|
memory_to_free = 1e32
|
||||||
if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
|
if not DISABLE_SMART_MEMORY or device is None:
|
||||||
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
||||||
if for_dynamic:
|
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||||
#don't actually unload dynamic models for the sake of other dynamic models
|
#don't actually unload dynamic models for the sake of other dynamic models
|
||||||
#as that works on-demand.
|
#as that works on-demand.
|
||||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||||
@ -817,6 +827,10 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
|||||||
for i in sorted(unloaded_model, reverse=True):
|
for i in sorted(unloaded_model, reverse=True):
|
||||||
unloaded_models.append(current_loaded_models.pop(i))
|
unloaded_models.append(current_loaded_models.pop(i))
|
||||||
|
|
||||||
|
if not for_dynamic and pins_required > 0:
|
||||||
|
ensure_pin_budget(pins_required)
|
||||||
|
ensure_pin_registerable(pins_required)
|
||||||
|
|
||||||
if len(unloaded_model) > 0:
|
if len(unloaded_model) > 0:
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
elif device is not None:
|
elif device is not None:
|
||||||
@ -879,15 +893,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
model_to_unload.model_finalizer.detach()
|
model_to_unload.model_finalizer.detach()
|
||||||
|
|
||||||
total_memory_required = {}
|
total_memory_required = {}
|
||||||
|
total_pins_required = {}
|
||||||
for loaded_model in models_to_load:
|
for loaded_model in models_to_load:
|
||||||
device = loaded_model.device
|
device = loaded_model.device
|
||||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||||
|
if not loaded_model.model.is_dynamic():
|
||||||
|
total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory()
|
||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||||
device,
|
device,
|
||||||
for_dynamic=free_for_dynamic)
|
for_dynamic=free_for_dynamic,
|
||||||
|
pins_required=total_pins_required.get(device, 0))
|
||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
@ -1283,7 +1301,6 @@ STREAM_CAST_BUFFERS = {}
|
|||||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||||
STREAM_PIN_BUFFERS = {}
|
|
||||||
|
|
||||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||||
|
|
||||||
@ -1326,42 +1343,13 @@ def get_aimdo_cast_buffer(offload_stream, device):
|
|||||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||||
return cast_buffer
|
return cast_buffer
|
||||||
|
|
||||||
def get_pin_buffer(offload_stream):
|
|
||||||
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
|
|
||||||
if pin_buffer is None:
|
|
||||||
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False)
|
|
||||||
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
|
|
||||||
elif offload_stream is not None:
|
|
||||||
event = getattr(pin_buffer, "_comfy_event", None)
|
|
||||||
if event is not None:
|
|
||||||
event.synchronize()
|
|
||||||
delattr(pin_buffer, "_comfy_event")
|
|
||||||
return pin_buffer
|
|
||||||
|
|
||||||
def resize_pin_buffer(pin_buffer, size):
|
|
||||||
global TOTAL_PINNED_MEMORY
|
|
||||||
old_size = pin_buffer.size
|
|
||||||
if size <= old_size:
|
|
||||||
return True
|
|
||||||
growth = size - old_size
|
|
||||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
|
||||||
ensure_pin_budget(growth, evict_active=True)
|
|
||||||
ensure_pin_registerable(growth, evict_active=True)
|
|
||||||
try:
|
|
||||||
pin_buffer.extend(size=size, reallocate=True)
|
|
||||||
except RuntimeError:
|
|
||||||
return False
|
|
||||||
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
|
|
||||||
return True
|
|
||||||
|
|
||||||
def reset_cast_buffers():
|
def reset_cast_buffers():
|
||||||
global TOTAL_PINNED_MEMORY
|
|
||||||
global LARGEST_CASTED_WEIGHT
|
global LARGEST_CASTED_WEIGHT
|
||||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||||
|
|
||||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
|
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||||
if offload_stream is not None:
|
if offload_stream is not None:
|
||||||
offload_stream.synchronize()
|
offload_stream.synchronize()
|
||||||
synchronize()
|
synchronize()
|
||||||
@ -1370,20 +1358,24 @@ def reset_cast_buffers():
|
|||||||
mmap_obj.bounce()
|
mmap_obj.bounce()
|
||||||
DIRTY_MMAPS.clear()
|
DIRTY_MMAPS.clear()
|
||||||
|
|
||||||
for pin_buffer in STREAM_PIN_BUFFERS.values():
|
|
||||||
TOTAL_PINNED_MEMORY -= pin_buffer.size
|
|
||||||
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
|
|
||||||
|
|
||||||
for loaded_model in current_loaded_models:
|
for loaded_model in current_loaded_models:
|
||||||
model = loaded_model.model
|
model = loaded_model.model
|
||||||
if model is not None and model.is_dynamic():
|
if model is not None and model.is_dynamic():
|
||||||
model.model.dynamic_pins[model.load_device]["active"] = False
|
pin_state = model.model.dynamic_pins[model.load_device]
|
||||||
|
|
||||||
|
if pin_state["active"]:
|
||||||
|
*_, buckets = pin_state["weights"]
|
||||||
|
for size, bucket in list(buckets.items()):
|
||||||
|
bucket[:] = [ entry for entry in bucket if entry[-1] is not None ]
|
||||||
|
if not bucket:
|
||||||
|
del buckets[size]
|
||||||
|
|
||||||
|
pin_state["active"] = False
|
||||||
model.partially_unload_ram(1e30, subsets=[ "patches" ])
|
model.partially_unload_ram(1e30, subsets=[ "patches" ])
|
||||||
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
|
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {})
|
||||||
|
|
||||||
STREAM_CAST_BUFFERS.clear()
|
STREAM_CAST_BUFFERS.clear()
|
||||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||||
STREAM_PIN_BUFFERS.clear()
|
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|
||||||
def get_offload_stream(device):
|
def get_offload_stream(device):
|
||||||
@ -1436,7 +1428,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
|||||||
if hasattr(wf_context, "as_context"):
|
if hasattr(wf_context, "as_context"):
|
||||||
wf_context = wf_context.as_context(stream)
|
wf_context = wf_context.as_context(stream)
|
||||||
|
|
||||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
|
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors)
|
||||||
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
|
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
|
||||||
with wf_context:
|
with wf_context:
|
||||||
for tensor in tensors:
|
for tensor in tensors:
|
||||||
@ -1448,9 +1440,10 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
|||||||
continue
|
continue
|
||||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||||
mark_mmap_dirty(storage)
|
mark_mmap_dirty(storage)
|
||||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
if dest_view is not None:
|
||||||
|
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||||
if dest2_view is not None:
|
if dest2_view is not None:
|
||||||
dest2_view.copy_(dest_view, non_blocking=non_blocking)
|
dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||||
|
|||||||
@ -1721,8 +1721,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
"""
|
"""
|
||||||
if device not in self.model.dynamic_pins:
|
if device not in self.model.dynamic_pins:
|
||||||
self.model.dynamic_pins[device] = {
|
self.model.dynamic_pins[device] = {
|
||||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
|
||||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
|
||||||
"hostbufs_initialized": False,
|
"hostbufs_initialized": False,
|
||||||
"failed": False,
|
"failed": False,
|
||||||
"active": False,
|
"active": False,
|
||||||
@ -1799,8 +1799,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
pin_state = self.model.dynamic_pins[self.load_device]
|
pin_state = self.model.dynamic_pins[self.load_device]
|
||||||
if not pin_state["hostbufs_initialized"]:
|
if not pin_state["hostbufs_initialized"]:
|
||||||
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
|
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
|
||||||
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
|
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
|
||||||
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
|
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
|
||||||
pin_state["hostbufs_initialized"] = True
|
pin_state["hostbufs_initialized"] = True
|
||||||
pin_state["failed"] = False
|
pin_state["failed"] = False
|
||||||
pin_state["active"] = True
|
pin_state["active"] = True
|
||||||
@ -1942,18 +1942,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
return freed
|
return freed
|
||||||
|
|
||||||
def loaded_ram_size(self):
|
def loaded_ram_size(self):
|
||||||
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
|
return (self.model.dynamic_pins[self.load_device]["weights"][0].size)
|
||||||
self.model.dynamic_pins[self.load_device]["patches"][0].size)
|
|
||||||
|
|
||||||
def pinned_memory_size(self):
|
def pinned_memory_size(self):
|
||||||
return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
|
return (self.model.dynamic_pins[self.load_device]["weights"][3][0])
|
||||||
self.model.dynamic_pins[self.load_device]["patches"][3][0])
|
|
||||||
|
|
||||||
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
|
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
|
||||||
freed = 0
|
freed = 0
|
||||||
pin_state = self.model.dynamic_pins[self.load_device]
|
pin_state = self.model.dynamic_pins[self.load_device]
|
||||||
for subset in subsets:
|
for subset in subsets:
|
||||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
|
||||||
split = stack_split[0]
|
split = stack_split[0]
|
||||||
while split >= 0:
|
while split >= 0:
|
||||||
module, offset = stack[split]
|
module, offset = stack[split]
|
||||||
@ -1978,10 +1976,12 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
freed = 0
|
freed = 0
|
||||||
pin_state = self.model.dynamic_pins[self.load_device]
|
pin_state = self.model.dynamic_pins[self.load_device]
|
||||||
for subset in subsets:
|
for subset in subsets:
|
||||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
|
||||||
while len(stack) > 0:
|
while len(stack) > 0:
|
||||||
module, offset = stack.pop()
|
module, offset = stack.pop()
|
||||||
size = module._pin.numel() * module._pin.element_size()
|
size = module._pin.numel() * module._pin.element_size()
|
||||||
|
module._pin_balancer_entry[-1] = None
|
||||||
|
del module._pin_balancer_entry
|
||||||
del module._pin
|
del module._pin
|
||||||
hostbuf.truncate(offset, do_unregister=module._pin_registered)
|
hostbuf.truncate(offset, do_unregister=module._pin_registered)
|
||||||
stack_split[0] = min(stack_split[0], len(stack) - 1)
|
stack_split[0] = min(stack_split[0], len(stack) - 1)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import comfy_aimdo.model_vbar
|
import comfy_aimdo.model_vbar
|
||||||
|
import comfy.memory_management
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
|
||||||
@ -50,7 +51,17 @@ def prefetch_queue_pop(queue, device, module):
|
|||||||
if hasattr(s, "_v"):
|
if hasattr(s, "_v"):
|
||||||
comfy_modules.append(s)
|
comfy_modules.append(s)
|
||||||
|
|
||||||
|
registerable_size = 0
|
||||||
|
for s in comfy_modules:
|
||||||
|
registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias])
|
||||||
|
for param_key in ("weight", "bias"):
|
||||||
|
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||||
|
if lowvram_fn is not None:
|
||||||
|
registerable_size += lowvram_fn.memory_required()
|
||||||
|
|
||||||
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||||
|
if not comfy.model_management.args.fast_disk:
|
||||||
|
comfy.model_management.ensure_pin_registerable(registerable_size)
|
||||||
comfy.model_management.sync_stream(device, offload_stream)
|
comfy.model_management.sync_stream(device, offload_stream)
|
||||||
queue[0] = (offload_stream, (prefetch, comfy_modules))
|
queue[0] = (offload_stream, (prefetch, comfy_modules))
|
||||||
|
|
||||||
|
|||||||
66
comfy/ops.py
66
comfy/ops.py
@ -76,8 +76,6 @@ except:
|
|||||||
|
|
||||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||||
|
|
||||||
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
|
|
||||||
|
|
||||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
@ -94,9 +92,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
|||||||
offload_stream = None
|
offload_stream = None
|
||||||
cast_buffer = None
|
cast_buffer = None
|
||||||
cast_buffer_offset = 0
|
cast_buffer_offset = 0
|
||||||
stream_pin_hostbuf = None
|
|
||||||
stream_pin_offset = 0
|
|
||||||
stream_pin_queue = []
|
|
||||||
|
|
||||||
def ensure_offload_stream(module, required_size, check_largest):
|
def ensure_offload_stream(module, required_size, check_largest):
|
||||||
nonlocal offload_stream
|
nonlocal offload_stream
|
||||||
@ -130,22 +125,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
|||||||
cast_buffer_offset += buffer_size
|
cast_buffer_offset += buffer_size
|
||||||
return buffer
|
return buffer
|
||||||
|
|
||||||
def get_stream_pin_buffer_offset(buffer_size):
|
|
||||||
nonlocal stream_pin_hostbuf
|
|
||||||
nonlocal stream_pin_offset
|
|
||||||
|
|
||||||
if buffer_size == 0 or offload_stream is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if stream_pin_hostbuf is None:
|
|
||||||
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
|
|
||||||
if stream_pin_hostbuf is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
offset = stream_pin_offset
|
|
||||||
stream_pin_offset += buffer_size
|
|
||||||
return offset
|
|
||||||
|
|
||||||
for s in comfy_modules:
|
for s in comfy_modules:
|
||||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||||
@ -184,12 +163,18 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
|||||||
if xfer_dest is None:
|
if xfer_dest is None:
|
||||||
xfer_dest = get_cast_buffer(dest_size)
|
xfer_dest = get_cast_buffer(dest_size)
|
||||||
|
|
||||||
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
|
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream, xfer_dest2=None):
|
||||||
if xfer_source is not None:
|
if xfer_source is not None:
|
||||||
if getattr(xfer_source, "is_lowvram_patch", False):
|
if getattr(xfer_source, "is_lowvram_patch", False):
|
||||||
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
if xfer_dest is not None:
|
||||||
else:
|
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
||||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
|
xfer_source = [ xfer_dest ]
|
||||||
|
xfer_dest = xfer_dest2
|
||||||
|
xfer_dest2 = None
|
||||||
|
elif xfer_dest2 is not None:
|
||||||
|
xfer_source.prepare(xfer_dest2, stream, copy=True, commit=False)
|
||||||
|
return
|
||||||
|
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream, r2=xfer_dest2)
|
||||||
|
|
||||||
def handle_pin(m, pin, source, dest, subset="weights", size=None):
|
def handle_pin(m, pin, source, dest, subset="weights", size=None):
|
||||||
if pin is not None:
|
if pin is not None:
|
||||||
@ -198,19 +183,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
|||||||
if signature is None:
|
if signature is None:
|
||||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||||
if pin is not None:
|
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
||||||
if isinstance(source, list):
|
|
||||||
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
|
|
||||||
else:
|
|
||||||
cast_maybe_lowvram_patch(source, pin, None)
|
|
||||||
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
|
|
||||||
return
|
|
||||||
if pin is None:
|
|
||||||
pin_offset = get_stream_pin_buffer_offset(size)
|
|
||||||
if pin_offset is not None:
|
|
||||||
stream_pin_queue.append((source, pin_offset, size, dest))
|
|
||||||
return
|
|
||||||
cast_maybe_lowvram_patch(source, dest, offload_stream)
|
|
||||||
|
|
||||||
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
|
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
|
||||||
|
|
||||||
@ -232,23 +205,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
|||||||
prefetch["needs_cast"] = needs_cast
|
prefetch["needs_cast"] = needs_cast
|
||||||
s._prefetch = prefetch
|
s._prefetch = prefetch
|
||||||
|
|
||||||
if stream_pin_offset > 0:
|
|
||||||
if stream_pin_hostbuf.size < stream_pin_offset:
|
|
||||||
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
|
|
||||||
for xfer_source, _, _, xfer_dest in stream_pin_queue:
|
|
||||||
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
|
|
||||||
return offload_stream
|
|
||||||
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
|
|
||||||
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
|
|
||||||
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
|
|
||||||
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
|
|
||||||
if isinstance(xfer_source, list):
|
|
||||||
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
|
|
||||||
else:
|
|
||||||
cast_maybe_lowvram_patch(xfer_source, pin, None)
|
|
||||||
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
|
||||||
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
|
|
||||||
|
|
||||||
return offload_stream
|
return offload_stream
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,17 +1,55 @@
|
|||||||
|
import bisect
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
|
import comfy.utils
|
||||||
import comfy_aimdo.host_buffer
|
import comfy_aimdo.host_buffer
|
||||||
import comfy_aimdo.torch
|
import comfy_aimdo.torch
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
def _add_to_bucket(module, buckets, size, priority):
    """Record ``module`` in the bucket of pins that share ``size``.

    Args:
        module: Object that owns the pin. It receives a
            ``_pin_balancer_entry`` attribute pointing at its bucket entry.
        buckets: Dict mapping a pin size to a sorted list of entries. The
            list for ``size`` is created on first use.
        size: Bucket key (the pin size the caller passes in).
        priority: Numeric priority of the module. It is stored negated, so
            the ascending sort keeps the lowest-priority entry at the end
            of the list.
    """
    bucket = buckets.setdefault(size, [])
    # Entry layout: [negated priority, tie-breaker, module].
    entry = [-priority, 0, module]
    # The entry's own id is distinct from that of every other live entry,
    # so equal priorities are ordered by it and list comparison does not
    # go on to compare the module objects themselves.
    entry[1] = id(entry)
    bisect.insort(bucket, entry)
    module._pin_balancer_entry = entry
|
||||||
|
|
||||||
|
def _steal_pin(module, stack, buckets, size, priority):
    """Try to hand an existing pin of ``size`` over to ``module``.

    The candidate victim is the last entry of the bucket for ``size``,
    which is the entry with the lowest priority because entries are sorted
    by negated priority. The pin is transferred only when ``priority`` is
    strictly greater than the victim's priority.

    Args:
        module: Object that should receive the pin.
        stack: Indexable sequence of tuples. The slot at the victim's
            ``_pin_stack_index`` has its first element replaced by
            ``module`` and its second element kept.
        buckets: Dict mapping a pin size to a sorted list of
            ``[negated priority, tie-breaker, module]`` entries.
        size: Bucket key to look for a victim in.
        priority: Priority of ``module``.

    Returns:
        True if a pin was transferred to ``module``, otherwise False.
    """
    bucket = buckets.get(size)
    if bucket is None:
        return False

    # Drop trailing entries whose module slot is None.
    # NOTE(review): presumably these are entries invalidated elsewhere;
    # confirm against the code that clears them.
    while bucket and bucket[-1][-1] is None:
        bucket.pop()
    if not bucket:
        del buckets[size]
        return False

    # Only steal from a strictly lower-priority holder.
    if priority <= -bucket[-1][0]:
        return False

    *_, victim = bucket.pop()
    module._pin = victim._pin
    module._pin_registered = victim._pin_registered
    module._pin_stack_index = victim._pin_stack_index
    stack[module._pin_stack_index] = (module, stack[module._pin_stack_index][1])

    # Strip the pin bookkeeping from the victim.
    victim._pin_registered = False
    del victim._pin
    del victim._pin_stack_index
    del victim._pin_balancer_entry

    _add_to_bucket(module, buckets, size, priority)
    return True
|
||||||
|
|
||||||
def get_pin(module, subset="weights"):
|
def get_pin(module, subset="weights"):
|
||||||
pin = getattr(module, "_pin", None)
|
pin = getattr(module, "_pin", None)
|
||||||
if pin is None or module._pin_registered or args.disable_pinned_memory:
|
if pin is None or module._pin_registered or args.disable_pinned_memory:
|
||||||
return pin
|
return pin
|
||||||
|
|
||||||
_, _, stack_split, pinned_size = module._pin_state[subset]
|
_, _, stack_split, pinned_size, *_ = module._pin_state[subset]
|
||||||
size = pin.nbytes
|
size = pin.nbytes
|
||||||
comfy.model_management.ensure_pin_registerable(size)
|
comfy.model_management.ensure_pin_registerable(size)
|
||||||
|
|
||||||
@ -31,26 +69,30 @@ def pin_memory(module, subset="weights", size=None):
|
|||||||
return
|
return
|
||||||
|
|
||||||
pin = get_pin(module, subset)
|
pin = get_pin(module, subset)
|
||||||
if pin is not None or pin_state["failed"]:
|
if pin is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
hostbuf, stack, stack_split, pinned_size, counter, buckets = pin_state[subset]
|
||||||
if size is None:
|
if size is None:
|
||||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||||
offset = hostbuf.size
|
offset = hostbuf.size
|
||||||
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
|
registerable_size = size
|
||||||
|
priority = getattr(module, "_pin_balancer_priority", None)
|
||||||
|
|
||||||
|
if priority is None:
|
||||||
|
priority = comfy.utils.bit_reverse_range(counter[0], 16)
|
||||||
|
counter[0] += 1
|
||||||
|
module._pin_balancer_priority = priority
|
||||||
|
|
||||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||||
if (not comfy.model_management.ensure_pin_budget(size) or
|
if (not comfy.model_management.ensure_pin_budget(size) or
|
||||||
not comfy.model_management.ensure_pin_registerable(registerable_size)):
|
not comfy.model_management.ensure_pin_registerable(registerable_size)):
|
||||||
pin_state["failed"] = True
|
return _steal_pin(module, stack, buckets, size, priority)
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
hostbuf.extend(size=size)
|
hostbuf.extend(size=size)
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pin_state["failed"] = True
|
return _steal_pin(module, stack, buckets, size, priority)
|
||||||
return False
|
|
||||||
|
|
||||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||||
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
|
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||||
@ -60,4 +102,5 @@ def pin_memory(module, subset="weights", size=None):
|
|||||||
stack_split[0] = max(stack_split[0], module._pin_stack_index)
|
stack_split[0] = max(stack_split[0], module._pin_stack_index)
|
||||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||||
pinned_size[0] += size
|
pinned_size[0] += size
|
||||||
|
_add_to_bucket(module, buckets, size, priority)
|
||||||
return True
|
return True
|
||||||
|
|||||||
@ -85,9 +85,9 @@ _TYPES = {
|
|||||||
def load_safetensors(ckpt):
|
def load_safetensors(ckpt):
|
||||||
import comfy_aimdo.model_mmap
|
import comfy_aimdo.model_mmap
|
||||||
|
|
||||||
f = open(ckpt, "rb", buffering=0)
|
|
||||||
file_lock = threading.Lock()
|
file_lock = threading.Lock()
|
||||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||||
|
f = model_mmap.get_file_handle()
|
||||||
file_size = os.path.getsize(ckpt)
|
file_size = os.path.getsize(ckpt)
|
||||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||||
|
|
||||||
@ -1452,3 +1452,10 @@ def deepcopy_list_dict(obj, memo=None):
|
|||||||
|
|
||||||
memo[obj_id] = res
|
memo[obj_id] = res
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def bit_reverse_range(index, bits):
    """Return the low ``bits`` bits of ``index`` in reversed order.

    Bit 0 of ``index`` becomes bit ``bits - 1`` of the result, bit 1
    becomes bit ``bits - 2``, and so on. Bits of ``index`` at position
    ``bits`` and above are ignored. When ``bits`` is zero or negative the
    loop does not run and the result is 0.

    >>> bit_reverse_range(1, 16)
    32768
    >>> bit_reverse_range(6, 3)
    3
    """
    result = 0
    for _ in range(bits):
        # Shift the accumulated bits up and append the current lowest bit.
        result = (result << 1) | (index & 1)
        index >>= 1
    return result
|
||||||
|
|||||||
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
|
|||||||
filelock
|
filelock
|
||||||
av>=16.0.0
|
av>=16.0.0
|
||||||
comfy-kitchen==0.2.10
|
comfy-kitchen==0.2.10
|
||||||
comfy-aimdo==0.4.5
|
comfy-aimdo==0.4.7
|
||||||
requests
|
requests
|
||||||
simpleeval>=1.0.0
|
simpleeval>=1.0.0
|
||||||
blake3
|
blake3
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user