Multi-threaded load of models from disk (big load time speedups & Offload to disk) (CORE-43,CORE-152,CORE-164,CORE-165,CORE-117) (#13802)

* model_management: disable non-dynamic smart memory

Disable smart memory outright for non dynamic models.

This is a minor step towards deprecation of --disable-dynamic-vram
and the legacy ModelPatcher.

This is needed for estimate-free model development, where new models
can opt-out of supplying a memory estimate and not have to worry
about hard VRAM allocations due to legacy non-dynamic model patchers

This is also a general stability increase for a lot of stray use cases
where estimates may still be off and going forward we are not going
to accurately maintain such estimates.

* pinned_memory: implement with aimdo growable buffer

Use a single growable buffer so we can do threaded pre-warming on
pinned memory.

* mm: use aimdo to do transfer from disk to pin

Aimdo implements a faster threaded loader.

* Add stream host pin buffer for AIMDO casts

Introduce per-offload-stream HostBuffer reuse for pinned staging,
include it in cast buffer reset synchronization.

Defer actual casts that go via this pin path to a separate pass
such that the buffer can be allocated monolithically (to avoid
cudaHostRegister thrash).

* remove old pin path

* Implement JIT pinned memory pressure

Replace the predictive pin pressure mechanism with JIT PIN memory
pressure.

* LowVRAMPatch: change to two-phase visit

* lora: re-implement as inplace swiss-army-knife operation

* prepare for multiple pin sets

* implement pinned loras

* requirements: comfy-aimdo 0.4.0

* ops: remove unused arg

This was defeatured in aimdo iteration

* ops: sync the CPU with only the offload stream activity

This was syncing with the offload stream which itself is synced with the
compute stream, so this was syncing CPU with compute transitively. Define
the event to sync it more gently.

* pins: implement freeing intermediate for pinned memory

Pinning is more important than inactive intermediates and the stream
pin buffer is more important than even active intermediates.

* execution: implement pin eviction on RAM presure

Add back proper pin freeing on RAM pressure

* implement pin registration swaps

Uncap the windows pins from 50% by extending the pool and have a pressure
mechanism to move the pin reservations om demand.

This unfortunately implies a GPU sync to do the freeing so significant
hysterisis needs to be added to consolidate these pressure events.

* cli_args/execution: Implement lower background cache-ram threshold

Limit the amount of RAM background intermediates can use, so that
switching workflows doesn't degrade performance too much.

* make default

* bump aimdo

* model-patcher: force-cast tiny weights

Flux 2 gets crazy stalls due to a mix of tiny and giant weights
creating lopsided steam buffer rotations which creates stalls.

* ops: refactor in prep for chunking

* mm: delegate pin-on-the-way to aimdo

Aimdo is able to chunk and slice this on the way for better CPU->GPU
overlap. The main advantage is the ability to shorten the bus contention
window between previous weight transfer and the next weights vbar
fault.

* bump aimdo

* pinning updates

* specify hostbuf max allocation size

There a signs of virtual memory exhaustion on some linux systems when
throwing 128GB for every little piece. Pass the actual to save aimdo
from over-estimates

* tests: update execution tests for caching

The default caching changed to ram-cache so update these tests
accordingly.

Remove the LRU 0 test as this also falls through to RAM cache.
This commit is contained in:
rattus 2026-05-21 10:03:58 +10:00 committed by GitHub
parent 4d6a058bf1
commit 5aa5ccc9e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 408 additions and 219 deletions

View File

@ -110,13 +110,11 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
CACHE_RAM_AUTO_GB = -1.0
cache_group = parser.add_mutually_exclusive_group() cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 25%% of system RAM (min 4GB, max 32GB), inactive 75%% of system RAM (min 12GB, max 96GB).")
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@ -245,6 +243,9 @@ if comfy.options.args_parsing:
else: else:
args = parser.parse_args([]) args = parser.parse_args([])
if args.cache_ram is not None and len(args.cache_ram) > 2:
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
if args.windows_standalone_build: if args.windows_standalone_build:
args.auto_launch = True args.auto_launch = True

View File

@ -484,16 +484,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
return weight return weight
def prefetch_prepared_value(value, allocate_buffer, stream): def prefetch_prepared_value(value, counter, destination, stream, copy):
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) size = comfy.memory_management.vram_aligned_size(value)
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) offset = counter[0]
counter[0] += size
if destination is None:
return value
dest = destination[offset:offset + size]
if copy:
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0] return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase): elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
elif isinstance(value, tuple): elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
elif isinstance(value, list): elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
return value return value

View File

@ -15,7 +15,7 @@ class TensorFileSlice(NamedTuple):
size: int size: int
def read_tensor_file_slice_into(tensor, destination): def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
if isinstance(tensor, QuantizedTensor): if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor): if not isinstance(destination, QuantizedTensor):
@ -23,12 +23,17 @@ def read_tensor_file_slice_into(tensor, destination):
if tensor._layout_cls != destination._layout_cls: if tensor._layout_cls != destination._layout_cls:
return False return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata): if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
destination2=(destination2._qdata if destination2 is not None else None)):
return False return False
dst_orig_dtype = destination._params.orig_dtype dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False) destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
if destination2 is not None:
dst_orig_dtype = destination2._params.orig_dtype
destination2._params.copy_from(destination._params, non_blocking=True)
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
return True return True
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None) info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
@ -48,6 +53,17 @@ def read_tensor_file_slice_into(tensor, destination):
if info.size == 0: if info.size == 0:
return True return True
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
device_ptr = destination2.data_ptr() if destination2 is not None else 0
hostbuf.read_file_slice(file_obj, info.offset, info.size,
offset=destination.data_ptr() - hostbuf.get_raw_address(),
stream=stream_ptr,
device_ptr=device_ptr,
device=None if destination2 is None else destination2.device.index)
return True
buf_type = ctypes.c_ubyte * info.size buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr())) view = memoryview(buf_type.from_address(destination.data_ptr()))
@ -151,7 +167,7 @@ def set_ram_cache_release_state(callback, headroom):
extra_ram_release_callback = callback extra_ram_release_callback = callback
RAM_CACHE_HEADROOM = max(0, int(headroom)) RAM_CACHE_HEADROOM = max(0, int(headroom))
def extra_ram_release(target): def extra_ram_release(target, free_active=False):
if extra_ram_release_callback is None: if extra_ram_release_callback is None:
return 0 return 0
return extra_ram_release_callback(target) return extra_ram_release_callback(target, free_active=free_active)

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management import comfy.memory_management
import comfy.utils import comfy.utils
import comfy.quant_ops import comfy.quant_ops
import comfy_aimdo.host_buffer
import comfy_aimdo.vram_buffer import comfy_aimdo.vram_buffer
class VRAMState(Enum): class VRAMState(Enum):
@ -495,6 +496,14 @@ except:
current_loaded_models = [] current_loaded_models = []
DIRTY_MMAPS = set()
PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024
#Freeing registerables on pressure does imply a GPU sync, so go big on
#the hysteresis so each expensive sync gives us back a good chunk.
REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024
def module_size(module): def module_size(module):
module_mem = 0 module_mem = 0
sd = module.state_dict() sd = module.state_dict()
@ -503,27 +512,46 @@ def module_size(module):
module_mem += t.nbytes module_mem += t.nbytes
return module_mem return module_mem
def module_mmap_residency(module, free=False): def mark_mmap_dirty(storage):
mmap_touched_mem = 0 mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
module_mem = 0 if mmap_refs is not None:
bounced_mmaps = set() DIRTY_MMAPS.add(mmap_refs[0])
sd = module.state_dict()
for k in sd: def free_pins(size, evict_active=False):
t = sd[k] freed_total = 0
module_mem += t.nbytes for loaded_model in reversed(current_loaded_models):
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage() if size <= 0:
if not getattr(storage, "_comfy_tensor_mmap_touched", False): return freed_total
continue model = loaded_model.model
mmap_touched_mem += t.nbytes if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
if not free: freed = model.partially_unload_ram(size)
continue freed_total += freed
storage._comfy_tensor_mmap_touched = False size -= freed
mmap_obj = storage._comfy_tensor_mmap_refs[0] return freed_total
if mmap_obj in bounced_mmaps:
continue def ensure_pin_budget(size, evict_active=False):
mmap_obj.bounce() shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
bounced_mmaps.add(mmap_obj) if shortfall <= 0:
return mmap_touched_mem, module_mem return True
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall
def ensure_pin_registerable(size, evict_active=False):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if shortfall <= 0:
return True
shortfall += REGISTERABLE_PIN_HYSTERESIS
for loaded_model in reversed(current_loaded_models):
model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
class LoadedModel: class LoadedModel:
def __init__(self, model): def __init__(self, model):
@ -553,9 +581,6 @@ class LoadedModel:
def model_memory(self): def model_memory(self):
return self.model.model_size() return self.model.model_size()
def model_mmap_residency(self, free=False):
return self.model.model_mmap_residency(free=free)
def model_loaded_memory(self): def model_loaded_memory(self):
return self.model.loaded_size() return self.model.loaded_size()
@ -635,15 +660,9 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS: if WINDOWS:
import comfy.windows
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
def get_free_ram():
return comfy.windows.get_free_ram()
else:
def get_free_ram():
return psutil.virtual_memory().available
if args.reserve_vram is not None: if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@ -657,7 +676,6 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc() cleanup_models_gc()
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
unloaded_models = [] unloaded_models = []
@ -673,11 +691,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for x in can_unload_sorted: for x in can_unload_sorted:
i = x[-1] i = x[-1]
memory_to_free = 1e32 memory_to_free = 1e32
pins_to_free = 1e32 if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = 0 if device is None else memory_required - get_free_memory(device) memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram() if for_dynamic:
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models #don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand. #as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size() memory_required -= current_loaded_models[i].model.loaded_size()
@ -685,18 +701,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i) unloaded_model.append(i)
if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
for i in sorted(unloaded_model, reverse=True): for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i)) unloaded_models.append(current_loaded_models.pop(i))
@ -762,29 +766,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_to_unload.model.detach(unpatch_all=False) model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach() model_to_unload.model_finalizer.detach()
total_memory_required = {} total_memory_required = {}
total_pins_required = {}
total_ram_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:
device = loaded_model.device device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
resident_memory, model_memory = loaded_model.model_mmap_residency()
pinned_memory = loaded_model.model.pinned_memory_size()
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, free_memory(total_memory_required[device] * 1.1 + extra_mem,
device, device,
for_dynamic=free_for_dynamic, for_dynamic=free_for_dynamic)
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
@ -1180,6 +1171,7 @@ STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {} STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
STREAM_PIN_BUFFERS = {}
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
@ -1220,21 +1212,66 @@ def get_aimdo_cast_buffer(offload_stream, device):
if cast_buffer is None: if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer return cast_buffer
def get_pin_buffer(offload_stream):
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
if pin_buffer is None:
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3))
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
elif offload_stream is not None:
event = getattr(pin_buffer, "_comfy_event", None)
if event is not None:
event.synchronize()
delattr(pin_buffer, "_comfy_event")
return pin_buffer
def resize_pin_buffer(pin_buffer, size):
global TOTAL_PINNED_MEMORY
old_size = pin_buffer.size
if size <= old_size:
return True
growth = size - old_size
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
ensure_pin_budget(growth, evict_active=True)
ensure_pin_registerable(growth, evict_active=True)
try:
pin_buffer.extend(size=size, reallocate=True)
except RuntimeError:
return False
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
return True
def reset_cast_buffers(): def reset_cast_buffers():
global TOTAL_PINNED_MEMORY
global LARGEST_CASTED_WEIGHT global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
if offload_stream is not None: if offload_stream is not None:
offload_stream.synchronize() offload_stream.synchronize()
synchronize() synchronize()
for mmap_obj in DIRTY_MMAPS:
mmap_obj.bounce()
DIRTY_MMAPS.clear()
for pin_buffer in STREAM_PIN_BUFFERS.values():
TOTAL_PINNED_MEMORY -= pin_buffer.size
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic():
model.model.dynamic_pins[model.load_device]["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
STREAM_CAST_BUFFERS.clear() STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear()
STREAM_PIN_BUFFERS.clear()
soft_empty_cache() soft_empty_cache()
def get_offload_stream(device): def get_offload_stream(device):
@ -1280,7 +1317,7 @@ def sync_stream(device, stream):
current_stream(device).wait_stream(stream) current_stream(device).wait_stream(stream)
def cast_to_gathered(tensors, r, non_blocking=False, stream=None): def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
wf_context = nullcontext() wf_context = nullcontext()
if stream is not None: if stream is not None:
wf_context = stream wf_context = stream
@ -1288,17 +1325,20 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
wf_context = wf_context.as_context(stream) wf_context = wf_context.as_context(stream)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
with wf_context: with wf_context:
for tensor in tensors: for tensor in tensors:
dest_view = dest_views.pop(0) dest_view = dest_views.pop(0)
dest2_view = dest2_views.pop(0) if dest2_views is not None else None
if tensor is None: if tensor is None:
continue continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view): if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view):
continue continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"): mark_mmap_dirty(storage)
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking) dest_view.copy_(tensor, non_blocking=non_blocking)
if dest2_view is not None:
dest2_view.copy_(dest_view, non_blocking=non_blocking)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
@ -1339,14 +1379,18 @@ TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1 MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory: if not args.disable_pinned_memory:
if is_nvidia() or is_amd(): if is_nvidia() or is_amd():
ram = get_total_memory(torch.device("cpu"))
if WINDOWS: if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50% MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50%
else: else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90 MAX_PINNED_MEMORY = ram * 0.90
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
def pinned_hostbuf_size(size):
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
def discard_cuda_async_error(): def discard_cuda_async_error():
try: try:
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
@ -1378,8 +1422,8 @@ def pin_memory(tensor):
return False return False
size = tensor.nbytes size = tensor.nbytes
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
return False ensure_pin_registerable(size)
ptr = tensor.data_ptr() ptr = tensor.data_ptr()
if ptr == 0: if ptr == 0:
@ -1416,7 +1460,8 @@ def unpin_memory(tensor):
return False return False
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) size = PINNED_MEMORY.pop(ptr)
TOTAL_PINNED_MEMORY -= size
return True return True
else: else:
logging.warning("Unpin error.") logging.warning("Unpin error.")

View File

@ -35,6 +35,7 @@ import comfy.model_management
import comfy.ops import comfy.ops
import comfy.patcher_extension import comfy.patcher_extension
import comfy.utils import comfy.utils
import comfy_aimdo.host_buffer
from comfy.comfy_types import UnetWrapperFunction from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -117,6 +118,8 @@ def string_to_seed(data):
return comfy.utils.string_to_seed(data) return comfy.utils.string_to_seed(data)
class LowVramPatch: class LowVramPatch:
is_lowvram_patch = True
def __init__(self, key, patches, convert_func=None, set_func=None): def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key self.key = key
self.patches = patches self.patches = patches
@ -124,11 +127,21 @@ class LowVramPatch:
self.set_func = set_func self.set_func = set_func
self.prepared_patches = None self.prepared_patches = None
def prepare(self, allocate_buffer, stream): def memory_required(self):
self.prepared_patches = [ counter = [0]
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) for patch in self.patches[self.key]:
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
return counter[0]
def prepare(self, destination, stream, copy=True, commit=True):
counter = [0]
prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
for patch in self.patches[self.key] for patch in self.patches[self.key]
] ]
if commit:
self.prepared_patches = prepared_patches
return prepared_patches
def clear_prepared(self): def clear_prepared(self):
self.prepared_patches = None self.prepared_patches = None
@ -341,9 +354,6 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model) self.size = comfy.model_management.module_size(self.model)
return self.size return self.size
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def loaded_size(self): def loaded_size(self):
return self.model.model_loaded_weight_memory return self.model.model_loaded_weight_memory
@ -1118,8 +1128,12 @@ class ModelPatcher:
# Pinned memory pressure tracking is only implemented for DynamicVram loading # Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0 return 0
def loaded_ram_size(self):
# Loaded RAM pressure tracking is only implemented for DynamicVram loading
return 0
def partially_unload_ram(self, ram_to_unload): def partially_unload_ram(self, ram_to_unload):
pass return 0
def detach(self, unpatch_all=True): def detach(self, unpatch_all=True):
self.eject_model() self.eject_model()
@ -1550,6 +1564,16 @@ class ModelPatcherDynamic(ModelPatcher):
super().__init__(model, load_device, offload_device, size, weight_inplace_update) super().__init__(model, load_device, offload_device, size, weight_inplace_update)
if not hasattr(self.model, "dynamic_vbars"): if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {} self.model.dynamic_vbars = {}
if not hasattr(self.model, "dynamic_pins"):
self.model.dynamic_pins = {}
if self.load_device not in self.model.dynamic_pins:
self.model.dynamic_pins[self.load_device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"hostbufs_initialized": False,
"failed": False,
"active": False,
}
self.non_dynamic_delegate_model = None self.non_dynamic_delegate_model = None
assert load_device is not None assert load_device is not None
@ -1611,6 +1635,14 @@ class ModelPatcherDynamic(ModelPatcher):
self.unpatch_hooks() self.unpatch_hooks()
vbar = self._vbar_get(create=True) vbar = self._vbar_get(create=True)
pin_state = self.model.dynamic_pins[self.load_device]
if not pin_state["hostbufs_initialized"]:
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["hostbufs_initialized"] = True
pin_state["failed"] = False
pin_state["active"] = True
if vbar is not None: if vbar is not None:
vbar.prioritize() vbar.prioritize()
@ -1636,7 +1668,9 @@ class ModelPatcherDynamic(ModelPatcher):
if key in self.patches: if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape: if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0) return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) lowvram_patch = LowVramPatch(key, self.patches)
lowvram_patch._pin_state = pin_state
setattr(m, param_key + "_lowvram_function", lowvram_patch)
num_patches += 1 num_patches += 1
else: else:
setattr(m, param_key + "_lowvram_function", None) setattr(m, param_key + "_lowvram_function", None)
@ -1653,6 +1687,9 @@ class ModelPatcherDynamic(ModelPatcher):
def force_load_param(self, param_key, device_to): def force_load_param(self, param_key, device_to):
key = key_param_name_to_key(n, param_key) key = key_param_name_to_key(n, param_key)
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return
if key in self.backup: if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to, force_cast=True) self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
@ -1662,17 +1699,23 @@ class ModelPatcherDynamic(ModelPatcher):
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True m.comfy_cast_weights = True
m.pin_failed = False
m.seed_key = n m.seed_key = n
m._pin_state = pin_state
set_dirty(m, dirty) set_dirty(m, dirty)
force_load, v_weight_size = setup_param(self, m, n, "weight") #Models that mix tiny and giant weights can causing lopsided stream buffer
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") #rotations and stall. force the tinys over.
force_load = force_load or force_load_bias if module_mem > 16 * 1024:
v_weight_size += v_weight_bias force_load, v_weight_size = setup_param(self, m, n, "weight")
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
else:
force_load=True
if force_load: if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
force_load_param(self, "weight", device_to) force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to) force_load_param(self, "bias", device_to)
else: else:
@ -1740,23 +1783,58 @@ class ModelPatcherDynamic(ModelPatcher):
return freed return freed
def pinned_memory_size(self): def loaded_ram_size(self):
total = 0 return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
loading = self._load_list(for_dynamic=True) self.model.dynamic_pins[self.load_device]["patches"][0].size)
for x in loading:
_, _, _, _, m, _ = x
pin = comfy.pinned_memory.get_pin(m)
if pin is not None:
total += pin.numel() * pin.element_size()
return total
def partially_unload_ram(self, ram_to_unload): def pinned_memory_size(self):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device) return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
for x in loading: self.model.dynamic_pins[self.load_device]["patches"][3][0])
*_, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m) def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
if ram_to_unload <= 0: freed = 0
return pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
split = stack_split[0]
while split >= 0:
module, offset = stack[split]
split -= 1
stack_split[0] = split
if not module._pin_registered:
continue
size = module._pin.numel() * module._pin.element_size()
if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0:
comfy.model_management.discard_cuda_async_error()
continue
module._pin_registered = False
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
pinned_size[0] = max(0, pinned_size[0] - size)
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:
return freed
return freed
def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
while len(stack) > 0:
module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size()
del module._pin
hostbuf.truncate(offset, do_unregister=module._pin_registered)
stack_split[0] = min(stack_split[0], len(stack) - 1)
if module._pin_registered:
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
pinned_size[0] = max(0, pinned_size[0] - size)
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:
return freed
return freed
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
#This isn't used by the core at all and can only be to load a model out of #This isn't used by the core at all and can only be to load a model out of

View File

@ -75,6 +75,8 @@ except:
cast_to = comfy.model_management.cast_to #TODO: remove once no more references cast_to = comfy.model_management.cast_to #TODO: remove once no more references
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
offload_stream = None offload_stream = None
cast_buffer = None cast_buffer = None
cast_buffer_offset = 0 cast_buffer_offset = 0
stream_pin_hostbuf = None
stream_pin_offset = 0
stream_pin_queue = []
def ensure_offload_stream(module, required_size, check_largest): def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream nonlocal offload_stream
@ -124,6 +129,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
cast_buffer_offset += buffer_size cast_buffer_offset += buffer_size
return buffer return buffer
def get_stream_pin_buffer_offset(buffer_size):
nonlocal stream_pin_hostbuf
nonlocal stream_pin_offset
if buffer_size == 0 or offload_stream is None:
return None
if stream_pin_hostbuf is None:
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
if stream_pin_hostbuf is None:
return None
offset = stream_pin_offset
stream_pin_offset += buffer_size
return offset
for s in comfy_modules: for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v) signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@ -162,23 +183,47 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if xfer_dest is None: if xfer_dest is None:
xfer_dest = get_cast_buffer(dest_size) xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None: def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
comfy.pinned_memory.pin_memory(s) if xfer_source is not None:
pin = comfy.pinned_memory.get_pin(s) if getattr(xfer_source, "is_lowvram_patch", False):
else: xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
pin = None else:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
if pin is not None: def handle_pin(m, pin, source, dest, subset="weights", size=None):
comfy.model_management.cast_to_gathered(xfer_source, pin) if pin is not None:
xfer_source = [ pin ] cast_maybe_lowvram_patch([pin], dest, offload_stream)
#send it over return
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) if signature is None:
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
pin = comfy.pinned_memory.get_pin(m, subset=subset)
if pin is not None:
if isinstance(source, list):
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
else:
cast_maybe_lowvram_patch(source, pin, None)
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
return
if pin is None:
pin_offset = get_stream_pin_buffer_offset(size)
if pin_offset is not None:
stream_pin_queue.append((source, pin_offset, size, dest))
return
cast_maybe_lowvram_patch(source, dest, offload_stream)
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
for param_key in ("weight", "bias"): for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None) lowvram_source = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None: if lowvram_source is not None:
ensure_offload_stream(s, cast_buffer_offset, False) ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) lowvram_size = lowvram_source.memory_required()
lowvram_dest = get_cast_buffer(lowvram_size)
lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size)
prefetch["xfer_dest"] = xfer_dest prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest prefetch["cast_dest"] = cast_dest
@ -186,6 +231,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
prefetch["needs_cast"] = needs_cast prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch s._prefetch = prefetch
if stream_pin_offset > 0:
if stream_pin_hostbuf.size < stream_pin_offset:
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
for xfer_source, _, _, xfer_dest in stream_pin_queue:
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
return offload_stream
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
if isinstance(xfer_source, list):
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
else:
cast_maybe_lowvram_patch(xfer_source, pin, None)
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
return offload_stream return offload_stream

View File

@ -2,42 +2,62 @@ import comfy.model_management
import comfy.memory_management import comfy.memory_management
import comfy_aimdo.host_buffer import comfy_aimdo.host_buffer
import comfy_aimdo.torch import comfy_aimdo.torch
import torch
from comfy.cli_args import args from comfy.cli_args import args
def get_pin(module): def get_pin(module, subset="weights"):
return getattr(module, "_pin", None) pin = getattr(module, "_pin", None)
if pin is None or module._pin_registered or args.disable_pinned_memory:
return pin
def pin_memory(module): _, _, stack_split, pinned_size = module._pin_state[subset]
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: size = pin.nbytes
comfy.model_management.ensure_pin_registerable(size)
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
return pin
module._pin_registered = True
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
return pin
def pin_memory(module, subset="weights", size=None):
pin_state = module._pin_state
if args.disable_pinned_memory:
return return
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) pin = get_pin(module, subset)
if pin is not None or pin_state["failed"]:
return
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: hostbuf, stack, stack_split, pinned_size = pin_state[subset]
module.pin_failed = True if size is None:
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
offset = hostbuf.size
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
if (not comfy.model_management.ensure_pin_budget(size) or
not comfy.model_management.ensure_pin_registerable(registerable_size)):
pin_state["failed"] = True
return False return False
try: try:
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size) hostbuf.extend(size=size)
except RuntimeError: except RuntimeError:
module.pin_failed = True pin_state["failed"] = True
return False return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf) module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
module._pin_hostbuf = hostbuf module._pin.untyped_storage()._comfy_hostbuf = hostbuf
stack.append((module, offset))
module._pin_registered = True
module._pin_stack_index = len(stack) - 1
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
return True return True
def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
del module._pin
del module._pin_hostbuf
return size

View File

@ -113,7 +113,6 @@ def load_safetensors(ckpt):
"_comfy_tensor_file_slice", "_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv)) setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
setattr(storage, "_comfy_tensor_mmap_touched", False)
sd[name] = tensor sd[name] = tensor
return sd, header.get("__metadata__", {}), return sd, header.get("__metadata__", {}),
@ -1451,4 +1450,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res memo[obj_id] = res
return res return res

View File

@ -1,52 +0,0 @@
import ctypes
import logging
import psutil
from ctypes import wintypes
import comfy_aimdo.control
psapi = ctypes.WinDLL("psapi")
kernel32 = ctypes.WinDLL("kernel32")
class PERFORMANCE_INFORMATION(ctypes.Structure):
_fields_ = [
("cb", wintypes.DWORD),
("CommitTotal", ctypes.c_size_t),
("CommitLimit", ctypes.c_size_t),
("CommitPeak", ctypes.c_size_t),
("PhysicalTotal", ctypes.c_size_t),
("PhysicalAvailable", ctypes.c_size_t),
("SystemCache", ctypes.c_size_t),
("KernelTotal", ctypes.c_size_t),
("KernelPaged", ctypes.c_size_t),
("KernelNonpaged", ctypes.c_size_t),
("PageSize", ctypes.c_size_t),
("HandleCount", wintypes.DWORD),
("ProcessCount", wintypes.DWORD),
("ThreadCount", wintypes.DWORD),
]
def get_free_ram():
#Windows is way too conservative and chalks recently used uncommitted model RAM
#as "in-use". So, calculate free RAM for the sake of general use as the greater of:
#
#1: What psutil says
#2: Total Memory - (Committed Memory - VRAM in use)
#
#We have to subtract VRAM in use from the comitted memory as WDDM creates a naked
#commit charge for all VRAM used just incase it wants to page it all out. This just
#isn't realistic so "overcommit" on our calculations by just subtracting it off.
pi = PERFORMANCE_INFORMATION()
pi.cb = ctypes.sizeof(pi)
if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal")
return psutil.virtual_memory().available
committed = pi.CommitTotal * pi.PageSize
total = pi.PhysicalTotal * pi.PageSize
return max(psutil.virtual_memory().available,
total - (committed - comfy_aimdo.control.get_total_vram_usage()))

View File

@ -2,6 +2,7 @@ import copy
import heapq import heapq
import inspect import inspect
import logging import logging
import psutil
import sys import sys
import threading import threading
import time import time
@ -727,6 +728,7 @@ class PromptExecutor:
self._notify_prompt_lifecycle("start", prompt_id) self._notify_prompt_lifecycle("start", prompt_id)
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024 ** 3))
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom) comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
@ -780,8 +782,14 @@ class PromptExecutor:
execution_list.complete_node_execution() execution_list.complete_node_execution()
if self.cache_type == CacheType.RAM_PRESSURE: if self.cache_type == CacheType.RAM_PRESSURE:
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom) ram_release_callback(ram_inactive_headroom)
ram_release_callback(ram_headroom, free_active=True) ram_shortfall = ram_headroom - psutil.virtual_memory().available
freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2))
if freed < ram_shortfall:
if freed > 64 * (1024 ** 2):
# AIMDO MEM_DECOMMIT can outrun psutil.available catching up.
time.sleep(0.05)
ram_release_callback(ram_headroom, free_active=True)
else: else:
# Only execute when the while-loop ends without break # Only execute when the while-loop ends without break
# Send cached UI for intermediate output nodes that weren't executed # Send cached UI for intermediate output nodes that weren't executed

20
main.py
View File

@ -283,19 +283,25 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
def prompt_worker(q, server_instance): def prompt_worker(q, server_instance):
current_time: float = 0.0 current_time: float = 0.0
cache_ram = args.cache_ram cache_ram = 0
if cache_ram < 0: cache_ram_inactive = 0
if not args.cache_classic and not args.cache_none and args.cache_lru <= 0:
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0)) cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
cache_ram_inactive = min(96.0, max(12.0, comfy.model_management.total_ram * 0.75 / 1024.0))
if len(args.cache_ram) > 0:
cache_ram = args.cache_ram[0]
if len(args.cache_ram) > 1:
cache_ram_inactive = args.cache_ram[1]
cache_type = execution.CacheType.CLASSIC cache_type = execution.CacheType.RAM_PRESSURE
if args.cache_lru > 0: if args.cache_classic:
cache_type = execution.CacheType.CLASSIC
elif args.cache_lru > 0:
cache_type = execution.CacheType.LRU cache_type = execution.CacheType.LRU
elif cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none: elif args.cache_none:
cache_type = execution.CacheType.NONE cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } ) e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram, "ram_inactive" : cache_ram_inactive } )
last_gc_collect = 0 last_gc_collect = 0
need_gc = False need_gc = False
gc_collect_interval = 10.0 gc_collect_interval = 10.0

View File

@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
filelock filelock
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.8 comfy-kitchen>=0.2.8
comfy-aimdo==0.3.0 comfy-aimdo==0.4.3
requests requests
simpleeval>=1.0.0 simpleeval>=1.0.0
blake3 blake3

View File

@ -14,7 +14,6 @@ from tests.execution.test_execution import ComfyClient, run_warmup
class TestAsyncNodes: class TestAsyncNodes:
@fixture(scope="class", autouse=True, params=[ @fixture(scope="class", autouse=True, params=[
(False, 0), (False, 0),
(True, 0),
(True, 100), (True, 100),
]) ])
def _server(self, args_pytest, request): def _server(self, args_pytest, request):
@ -29,6 +28,8 @@ class TestAsyncNodes:
use_lru, lru_size = request.param use_lru, lru_size = request.param
if use_lru: if use_lru:
pargs += ['--cache-lru', str(lru_size)] pargs += ['--cache-lru', str(lru_size)]
else:
pargs += ['--cache-classic']
# Running server with args: pargs # Running server with args: pargs
p = subprocess.Popen(pargs) p = subprocess.Popen(pargs)
yield yield

View File

@ -183,8 +183,7 @@ class TestExecution:
# Initialize server and client # Initialize server and client
# #
@fixture(scope="class", autouse=True, params=[ @fixture(scope="class", autouse=True, params=[
{ "extra_args" : [], "should_cache_results" : True }, { "extra_args" : ["--cache-classic"], "should_cache_results" : True },
{ "extra_args" : ["--cache-lru", 0], "should_cache_results" : True },
{ "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True },
{ "extra_args" : ["--cache-none"], "should_cache_results" : False }, { "extra_args" : ["--cache-none"], "should_cache_results" : False },
]) ])