diff --git a/comfy/lora.py b/comfy/lora.py
index db8f16bcb..2b8699710 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -475,16 +475,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
     return weight
 
-def prefetch_prepared_value(value, allocate_buffer, stream):
+def prefetch_prepared_value(value, counter, destination, stream, copy):
     if isinstance(value, torch.Tensor):
-        dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
-        comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
+        size = comfy.memory_management.vram_aligned_size(value)
+        offset = counter[0]
+        counter[0] += size
+        if destination is None:
+            return value
+
+        dest = destination[offset:offset + size]
+        if copy:
+            comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
         return comfy.memory_management.interpret_gathered_like([value], dest)[0]
     elif isinstance(value, weight_adapter.WeightAdapterBase):
-        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
+        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
     elif isinstance(value, tuple):
-        return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
+        return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
     elif isinstance(value, list):
-        return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
+        return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
     return value
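Note on the new calling convention: prefetch_prepared_value no longer allocates per tensor; it is driven in two passes over one flat staging buffer. A minimal sketch of the intended sequence, assuming a caller-provided flat buffer (allocate_flat_buffer is hypothetical, not part of this diff):

    # Pass 1: sizing. With destination=None nothing is copied; the walk only
    # accumulates vram-aligned sizes into the mutable one-element counter.
    counter = [0]
    comfy.lora.prefetch_prepared_value(value, counter, None, None, False)
    total = counter[0]

    # Pass 2: fill. Each tensor claims destination[offset:offset + size], so
    # nested adapters, tuples, and lists pack contiguously into one buffer.
    buf = allocate_flat_buffer(total)  # hypothetical staging allocator
    counter = [0]
    prepared = comfy.lora.prefetch_prepared_value(value, counter, buf, stream, True)

LowVramPatch.memory_required() and LowVramPatch.prepare() in comfy/model_patcher.py below are the in-tree users of this protocol.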
diff --git a/comfy/memory_management.py b/comfy/memory_management.py
index 48e3c11da..4a628b05c 100644
--- a/comfy/memory_management.py
+++ b/comfy/memory_management.py
@@ -48,6 +48,12 @@ def read_tensor_file_slice_into(tensor, destination):
     if info.size == 0:
         return True
 
+    hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
+    if hostbuf is not None:
+        hostbuf.read_file_slice(file_obj, info.offset, info.size,
+                                offset=destination.data_ptr() - hostbuf.get_raw_address())
+        return True
+
     buf_type = ctypes.c_ubyte * info.size
     view = memoryview(buf_type.from_address(destination.data_ptr()))
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 21738a4c7..145a32080 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -31,6 +31,7 @@ from contextlib import nullcontext
 import comfy.memory_management
 import comfy.utils
 import comfy.quant_ops
+import comfy_aimdo.host_buffer
 import comfy_aimdo.vram_buffer
 
 class VRAMState(Enum):
@@ -495,6 +496,10 @@ except:
 
 current_loaded_models = []
 
+DIRTY_MMAPS = set()
+
+PIN_PRESSURE_HYSTERESIS = 128 * 1024 * 1024
+
 def module_size(module):
     module_mem = 0
     sd = module.state_dict()
@@ -503,27 +508,26 @@ def module_size(module):
         module_mem += t.nbytes
     return module_mem
 
-def module_mmap_residency(module, free=False):
-    mmap_touched_mem = 0
-    module_mem = 0
-    bounced_mmaps = set()
-    sd = module.state_dict()
-    for k in sd:
-        t = sd[k]
-        module_mem += t.nbytes
-        storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
-        if not getattr(storage, "_comfy_tensor_mmap_touched", False):
-            continue
-        mmap_touched_mem += t.nbytes
-        if not free:
-            continue
-        storage._comfy_tensor_mmap_touched = False
-        mmap_obj = storage._comfy_tensor_mmap_refs[0]
-        if mmap_obj in bounced_mmaps:
-            continue
-        mmap_obj.bounce()
-        bounced_mmaps.add(mmap_obj)
-
-    return mmap_touched_mem, module_mem
+def mark_mmap_dirty(storage):
+    mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
+    if mmap_refs is not None:
+        DIRTY_MMAPS.add(mmap_refs[0])
+
+def ensure_pin_budget(size, evict_active=False):
+    if MAX_PINNED_MEMORY <= 0:
+        return
+
+    shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
+    if shortfall <= 0:
+        return
+
+    shortfall += PIN_PRESSURE_HYSTERESIS
+    for loaded_model in reversed(current_loaded_models):
+        model = loaded_model.model
+        if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
+            shortfall -= model.partially_unload_ram(shortfall)
+            if shortfall <= 0:
+                break
 
 class LoadedModel:
     def __init__(self, model):
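Note on ensure_pin_budget: instead of hard-failing when the pin cap is hit (the old behavior in pin_memory), it reclaims pinned RAM from models at the tail of current_loaded_models, skips models whose pin state is "active" unless evict_active=True, and deliberately over-frees by PIN_PRESSURE_HYSTERESIS so back-to-back requests don't thrash at the cap. A worked example with assumed numbers (not taken from the diff): with MAX_PINNED_MEMORY at 8 GiB and TOTAL_PINNED_MEMORY at 7.5 GiB, ensure_pin_budget(1 GiB) computes a 0.5 GiB shortfall, adds the 128 MiB hysteresis, and asks eligible dynamic models to unpin roughly 640 MiB.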
@@ -553,9 +557,6 @@
     def model_memory(self):
         return self.model.model_size()
 
-    def model_mmap_residency(self, free=False):
-        return self.model.model_mmap_residency(free=free)
-
     def model_loaded_memory(self):
         return self.model.loaded_size()
 
@@ -635,15 +636,9 @@ WINDOWS = any(platform.win32_ver())
 
 EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
 if WINDOWS:
-    import comfy.windows
     EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
     if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
         EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
-    def get_free_ram():
-        return comfy.windows.get_free_ram()
-else:
-    def get_free_ram():
-        return psutil.virtual_memory().available
 
 if args.reserve_vram is not None:
     EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@@ -657,7 +652,6 @@ def minimum_inference_memory():
 
 def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
     cleanup_models_gc()
-    comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -673,11 +667,9 @@
     for x in can_unload_sorted:
         i = x[-1]
         memory_to_free = 1e32
-        pins_to_free = 1e32
-        if not DISABLE_SMART_MEMORY or device is None:
+        if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
             memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
-            pins_to_free = pins_required - get_free_ram()
-        if current_loaded_models[i].model.is_dynamic() and for_dynamic:
+        if for_dynamic:
             #don't actually unload dynamic models for the sake of other dynamic models
             #as that works on-demand.
             memory_required -= current_loaded_models[i].model.loaded_size()
 
@@ -685,18 +677,6 @@
         if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
             logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
             unloaded_model.append(i)
-        if pins_to_free > 0:
-            logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-            current_loaded_models[i].model.partially_unload_ram(pins_to_free)
-
-    for x in can_unload_sorted:
-        i = x[-1]
-        ram_to_free = ram_required - psutil.virtual_memory().available
-        if ram_to_free <= 0 and i not in unloaded_model:
-            continue
-        resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
-        if resident_memory > 0:
-            logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
 
     for i in sorted(unloaded_model, reverse=True):
         unloaded_models.append(current_loaded_models.pop(i))
 
@@ -762,29 +742,16 @@
             model_to_unload.model.detach(unpatch_all=False)
             model_to_unload.model_finalizer.detach()
-
     total_memory_required = {}
-    total_pins_required = {}
-    total_ram_required = {}
     for loaded_model in models_to_load:
         device = loaded_model.device
         total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
-        resident_memory, model_memory = loaded_model.model_mmap_residency()
-        pinned_memory = loaded_model.model.pinned_memory_size()
-        #FIXME: This can over-free the pins as it budgets to pin the entire model. We should
-        #make this JIT to keep as much pinned as possible.
-        pins_required = model_memory - pinned_memory
-        ram_required = model_memory - resident_memory
-        total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
-        total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
             free_memory(total_memory_required[device] * 1.1 + extra_mem, device,
-                        for_dynamic=free_for_dynamic,
-                        pins_required=total_pins_required[device],
-                        ram_required=total_ram_required[device])
+                        for_dynamic=free_for_dynamic)
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
 
@@ -1180,6 +1147,7 @@ STREAM_CAST_BUFFERS = {}
 LARGEST_CASTED_WEIGHT = (None, 0)
 STREAM_AIMDO_CAST_BUFFERS = {}
 LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
+STREAM_PIN_BUFFERS = {}
 
 DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
 
@@ -1220,21 +1188,62 @@ def get_aimdo_cast_buffer(offload_stream, device):
     if cast_buffer is None:
         cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
         STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
-
     return cast_buffer
+
+def get_pin_buffer(offload_stream):
+    pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
+    if pin_buffer is None:
+        pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0)
+        STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
+    elif offload_stream is not None:
+        offload_stream.synchronize()
+    return pin_buffer
+
+def resize_pin_buffer(pin_buffer, size):
+    global TOTAL_PINNED_MEMORY
+    old_size = pin_buffer.size
+    if size <= old_size:
+        return True
+    growth = size - old_size
+    ensure_pin_budget(growth, evict_active=True)
+    try:
+        pin_buffer.extend(size=size, reallocate=True)
+    except RuntimeError:
+        return False
+    TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
+    return True
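Note: get_pin_buffer hands out one lazily created, zero-sized HostBuffer per offload stream and synchronizes the stream before reusing an existing buffer, presumably so in-flight DMA never reads memory that is about to be rewritten; resize_pin_buffer only ever grows it. A hedged usage sketch, assuming the names from this hunk:

    pin_buffer = comfy.model_management.get_pin_buffer(offload_stream)
    if not comfy.model_management.resize_pin_buffer(pin_buffer, needed_bytes):
        # over budget or allocation failed: callers are expected to fall back
        # to unpinned transfers (see the comfy/ops.py hunk below)
        pass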
+
 def reset_cast_buffers():
+    global TOTAL_PINNED_MEMORY
     global LARGEST_CASTED_WEIGHT
     global LARGEST_AIMDO_CASTED_WEIGHT
     LARGEST_CASTED_WEIGHT = (None, 0)
     LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
-    for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
+    for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
         if offload_stream is not None:
             offload_stream.synchronize()
     synchronize()
 
+    for mmap_obj in DIRTY_MMAPS:
+        mmap_obj.bounce()
+    DIRTY_MMAPS.clear()
+
+    for pin_buffer in STREAM_PIN_BUFFERS.values():
+        TOTAL_PINNED_MEMORY -= pin_buffer.size
+    if TOTAL_PINNED_MEMORY < 0:
+        TOTAL_PINNED_MEMORY = 0
+
+    for loaded_model in current_loaded_models:
+        model = loaded_model.model
+        if model is not None and model.is_dynamic():
+            model.model.dynamic_pins[model.load_device]["active"] = False
+            model.partially_unload_ram(1e30, subsets=[ "patches" ])
+            model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), [])
+
     STREAM_CAST_BUFFERS.clear()
     STREAM_AIMDO_CAST_BUFFERS.clear()
+    STREAM_PIN_BUFFERS.clear()
     soft_empty_cache()
 
 def get_offload_stream(device):
 
@@ -1296,8 +1305,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
         if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
             continue
         storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
-        if hasattr(storage, "_comfy_tensor_mmap_touched"):
-            storage._comfy_tensor_mmap_touched = True
+        mark_mmap_dirty(storage)
         dest_view.copy_(tensor, non_blocking=non_blocking)
 
@@ -1378,8 +1386,7 @@ def pin_memory(tensor):
         return False
 
     size = tensor.nbytes
-    if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
-        return False
+    ensure_pin_budget(size)
 
     ptr = tensor.data_ptr()
     if ptr == 0:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 33bdedfb1..6769dd01f 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -35,6 +35,7 @@ import comfy.model_management
 import comfy.ops
 import comfy.patcher_extension
 import comfy.utils
+import comfy_aimdo.host_buffer
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@@ -117,6 +118,8 @@ def string_to_seed(data):
     return comfy.utils.string_to_seed(data)
 
 class LowVramPatch:
+    is_lowvram_patch = True
+
     def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
@@ -124,11 +127,21 @@ class LowVramPatch:
         self.set_func = set_func
         self.prepared_patches = None
 
-    def prepare(self, allocate_buffer, stream):
-        self.prepared_patches = [
-            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
+    def memory_required(self):
+        counter = [0]
+        for patch in self.patches[self.key]:
+            comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
+        return counter[0]
+
+    def prepare(self, destination, stream, copy=True, commit=True):
+        counter = [0]
+        prepared_patches = [
+            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
             for patch in self.patches[self.key]
         ]
+        if commit:
+            self.prepared_patches = prepared_patches
+        return prepared_patches
 
     def clear_prepared(self):
         self.prepared_patches = None
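Note: LowVramPatch now separates sizing from staging, reusing the two-pass prefetch_prepared_value protocol from comfy/lora.py above. A sketch of the call pattern (mirroring the comfy/ops.py hunk below; get_cast_buffer, pin, and stream stand in for whatever the caller supplies):

    patch = LowVramPatch(key, patches)
    n = patch.memory_required()                          # dry run, destination=None
    dest = get_cast_buffer(n)                            # one flat slice for all patch tensors
    patch.prepare(dest, None, copy=False, commit=True)   # lay out views, copy nothing
    patch.prepare(pin, stream, copy=True, commit=False)  # later: stage the real bytes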
@@ -310,9 +323,6 @@ class ModelPatcher:
             self.size = comfy.model_management.module_size(self.model)
         return self.size
 
-    def model_mmap_residency(self, free=False):
-        return comfy.model_management.module_mmap_residency(self.model, free=free)
-
     def loaded_size(self):
         return self.model.model_loaded_weight_memory
 
@@ -1088,7 +1098,7 @@ class ModelPatcher:
         return 0
 
     def partially_unload_ram(self, ram_to_unload):
-        pass
+        return 0
 
     def detach(self, unpatch_all=True):
         self.eject_model()
@@ -1495,6 +1505,15 @@ class ModelPatcherDynamic(ModelPatcher):
         super().__init__(model, load_device, offload_device, size, weight_inplace_update)
         if not hasattr(self.model, "dynamic_vbars"):
             self.model.dynamic_vbars = {}
+        if not hasattr(self.model, "dynamic_pins"):
+            self.model.dynamic_pins = {}
+        if self.load_device not in self.model.dynamic_pins:
+            self.model.dynamic_pins[self.load_device] = {
+                "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024), []),
+                "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), []),
+                "failed": False,
+                "active": False,
+            }
         self.non_dynamic_delegate_model = None
 
         assert load_device is not None
@@ -1556,6 +1575,9 @@ class ModelPatcherDynamic(ModelPatcher):
         self.unpatch_hooks()
 
         vbar = self._vbar_get(create=True)
+        pin_state = self.model.dynamic_pins[self.load_device]
+        pin_state["failed"] = False
+        pin_state["active"] = True
         if vbar is not None:
             vbar.prioritize()
 
@@ -1581,7 +1603,9 @@ class ModelPatcherDynamic(ModelPatcher):
                 if key in self.patches:
                     if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
                         return (True, 0)
-                    setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
+                    lowvram_patch = LowVramPatch(key, self.patches)
+                    lowvram_patch._pin_state = pin_state
+                    setattr(m, param_key + "_lowvram_function", lowvram_patch)
                     num_patches += 1
                 else:
                     setattr(m, param_key + "_lowvram_function", None)
@@ -1607,8 +1631,8 @@ class ModelPatcherDynamic(ModelPatcher):
             if hasattr(m, "comfy_cast_weights"):
                 m.comfy_cast_weights = True
 
-            m.pin_failed = False
             m.seed_key = n
+            m._pin_state = pin_state
 
             set_dirty(m, dirty)
             force_load, v_weight_size = setup_param(self, m, n, "weight")
@@ -1686,22 +1710,27 @@ class ModelPatcherDynamic(ModelPatcher):
         return freed
 
     def pinned_memory_size(self):
-        total = 0
-        loading = self._load_list(for_dynamic=True)
-        for x in loading:
-            _, _, _, _, m, _ = x
-            pin = comfy.pinned_memory.get_pin(m)
-            if pin is not None:
-                total += pin.numel() * pin.element_size()
-        return total
+        return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
+                self.model.dynamic_pins[self.load_device]["patches"][0].size)
 
-    def partially_unload_ram(self, ram_to_unload):
-        loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
-        for x in loading:
-            *_, m, _ = x
-            ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
-            if ram_to_unload <= 0:
-                return
+    def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
+        freed = 0
+        pin_state = self.model.dynamic_pins[self.load_device]
+        for subset in subsets:
+            hostbuf, stack = pin_state[subset]
+            while len(stack) > 0:
+                module, offset = stack.pop()
+                size = module._pin.numel() * module._pin.element_size()
+                del module._pin
+                hostbuf.truncate(offset)
+                comfy.model_management.TOTAL_PINNED_MEMORY -= size
+                if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
+                    comfy.model_management.TOTAL_PINNED_MEMORY = 0
+                freed += size
+                ram_to_unload -= size
+                if ram_to_unload <= 0:
+                    return freed
+        return freed
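Note: per-device pins are now a bump allocator (the HostBuffer) plus a stack of (module, offset) records, so unpinning is strictly LIFO and the buffer never fragments: pin_memory in comfy/pinned_memory.py below extends and appends, while partially_unload_ram pops and truncates. Conceptually, assuming HostBuffer.truncate(offset) releases everything from offset to the end of the buffer:

    hostbuf, stack = pin_state["weights"]
    module, offset = stack.pop()   # most recently pinned module first
    hostbuf.truncate(offset)       # its pin occupied the tail of the buffer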
 
     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
         #This isn't used by the core at all and can only be to load a model out of
diff --git a/comfy/ops.py b/comfy/ops.py
index 77ad1d527..629b54e4c 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -75,6 +75,8 @@ except:
 
 cast_to = comfy.model_management.cast_to #TODO: remove once no more references
 
+STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
+
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
 
@@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
     offload_stream = None
     cast_buffer = None
     cast_buffer_offset = 0
+    stream_pin_hostbuf = None
+    stream_pin_offset = 0
+    stream_pin_queue = []
 
     def ensure_offload_stream(module, required_size, check_largest):
         nonlocal offload_stream
@@ -124,6 +129,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
         cast_buffer_offset += buffer_size
         return buffer
 
+    def get_stream_pin_buffer_offset(buffer_size):
+        nonlocal stream_pin_hostbuf
+        nonlocal stream_pin_offset
+
+        if buffer_size == 0 or offload_stream is None:
+            return None
+
+        if stream_pin_hostbuf is None:
+            stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
+        if stream_pin_hostbuf is None:
+            return None
+
+        offset = stream_pin_offset
+        stream_pin_offset += buffer_size
+        return offset
+
     for s in comfy_modules:
         signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
         resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@@ -162,23 +183,45 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
             if xfer_dest is None:
                 xfer_dest = get_cast_buffer(dest_size)
 
-            if signature is None and pin is None:
-                comfy.pinned_memory.pin_memory(s)
-                pin = comfy.pinned_memory.get_pin(s)
-            else:
-                pin = None
+            def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
+                if xfer_source is not None:
+                    if getattr(xfer_source, "is_lowvram_patch", False):
+                        xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
+                    else:
+                        comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
 
-            if pin is not None:
-                comfy.model_management.cast_to_gathered(xfer_source, pin)
-                xfer_source = [ pin ]
+            def handle_pin_miss(m, source, dest, subset="weights", size=None):
+                pin = None
+                if signature is None:
+                    comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
+                    pin = comfy.pinned_memory.get_pin(m, subset=subset)
+                if pin is not None:
+                    cast_maybe_lowvram_patch(source, pin, None)
+                    return [ pin ]
+                if pin is None:
+                    pin_offset = get_stream_pin_buffer_offset(size)
+                    if pin_offset is not None:
+                        stream_pin_queue.append((source, pin_offset, size, dest))
+                        return None
+                return source
+
+            if pin is None:
+                xfer_source = handle_pin_miss(s, xfer_source, xfer_dest, size=dest_size)
 
-            #send it over
-            comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+            cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
 
             for param_key in ("weight", "bias"):
-                lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
-                if lowvram_fn is not None:
+                lowvram_source = getattr(s, param_key + "_lowvram_function", None)
+                if lowvram_source is not None:
                     ensure_offload_stream(s, cast_buffer_offset, False)
-                    lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
+                    lowvram_size = lowvram_source.memory_required()
+                    lowvram_dest = get_cast_buffer(lowvram_size)
+                    lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
+
+                    pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
+                    lowvram_source = handle_pin_miss(lowvram_source, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) if pin is None else [ pin ]
+
+                    cast_maybe_lowvram_patch(lowvram_source, lowvram_dest, offload_stream)
 
             prefetch["xfer_dest"] = xfer_dest
             prefetch["cast_dest"] = cast_dest
@@ -186,6 +229,19 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
             prefetch["needs_cast"] = needs_cast
         s._prefetch = prefetch
 
+    if stream_pin_offset > 0:
+        if stream_pin_hostbuf.size < stream_pin_offset:
+            if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
+                for xfer_source, _, _, xfer_dest in stream_pin_queue:
+                    cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
+                return offload_stream
+        stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
+        stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
+        for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
+            pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
+            cast_maybe_lowvram_patch(xfer_source, pin, None)
+            comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+
     return offload_stream
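Note: pin misses are no longer serviced one module at a time. handle_pin_miss queues (source, offset, size, dest) tuples, and after the module loop the per-stream pin buffer is grown once to cover the whole batch; each queued transfer is then staged through its slice of a single pinned tensor and DMA'd on the offload stream. Assumed flow, one batch per cast pass:

    # 1. miss           -> stream_pin_queue.append((source, offset, size, dest))
    # 2. after the loop -> resize_pin_buffer(hostbuf, total + STREAM_PIN_BUFFER_HEADROOM)
    # 3. per entry      -> pin = stream_pin_tensor[offset:offset + size]
    #                      cast_maybe_lowvram_patch(source, pin, None)           # CPU-side stage
    #                      cast_to_gathered([pin], dest, stream=offload_stream)  # async DMA

If the resize fails, every queued entry falls back to a direct unpinned cast_maybe_lowvram_patch, so correctness does not depend on pinning succeeding.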
diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py
index 6d3ba367a..208c777f8 100644
--- a/comfy/pinned_memory.py
+++ b/comfy/pinned_memory.py
@@ -5,39 +5,28 @@ import comfy_aimdo.torch
 
 from comfy.cli_args import args
 
-def get_pin(module):
+def get_pin(module, subset="weights"):
     return getattr(module, "_pin", None)
 
-def pin_memory(module):
-    if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
+def pin_memory(module, subset="weights", size=None):
+    pin_state = module._pin_state
+    if pin_state["failed"] or args.disable_pinned_memory or get_pin(module, subset) is not None:
         return
 
-    size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
-
-    if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
-        module.pin_failed = True
-        return False
+    hostbuf, stack = pin_state[subset]
+    if size is None:
+        size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
+    offset = hostbuf.size
+    comfy.model_management.ensure_pin_budget(size)
     try:
-        hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
+        hostbuf.extend(size=size)
     except RuntimeError:
-        module.pin_failed = True
+        pin_state["failed"] = True
         return False
 
-    module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
-    module._pin_hostbuf = hostbuf
+    module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
+    module._pin.untyped_storage()._comfy_hostbuf = hostbuf
+    stack.append((module, offset))
 
     comfy.model_management.TOTAL_PINNED_MEMORY += size
     return True
-
-def unpin_memory(module):
-    if get_pin(module) is None:
-        return 0
-    size = module._pin.numel() * module._pin.element_size()
-
-    comfy.model_management.TOTAL_PINNED_MEMORY -= size
-    if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
-        comfy.model_management.TOTAL_PINNED_MEMORY = 0
-
-    del module._pin
-    del module._pin_hostbuf
-    return size
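Note: stashing _comfy_hostbuf on the pin's untyped storage is what activates the fast path added to read_tensor_file_slice_into in comfy/memory_management.py above: file bytes stream straight into the pinned region, with the byte offset recomputed from raw addresses rather than tensor metadata:

    offset = destination.data_ptr() - hostbuf.get_raw_address()
    hostbuf.read_file_slice(file_obj, info.offset, info.size, offset=offset)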
diff --git a/comfy/utils.py b/comfy/utils.py
index 91e1ba3d3..048473a61 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -113,7 +113,6 @@ def load_safetensors(ckpt):
                 "_comfy_tensor_file_slice",
                 comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
             setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
-            setattr(storage, "_comfy_tensor_mmap_touched", False)
             sd[name] = tensor
 
     return sd, header.get("__metadata__", {}),
@@ -1445,4 +1444,3 @@ def deepcopy_list_dict(obj, memo=None):
         memo[obj_id] = res
 
     return res
-
diff --git a/comfy/windows.py b/comfy/windows.py
deleted file mode 100644
index 213dc481d..000000000
--- a/comfy/windows.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import ctypes
-import logging
-import psutil
-from ctypes import wintypes
-
-import comfy_aimdo.control
-
-psapi = ctypes.WinDLL("psapi")
-kernel32 = ctypes.WinDLL("kernel32")
-
-class PERFORMANCE_INFORMATION(ctypes.Structure):
-    _fields_ = [
-        ("cb", wintypes.DWORD),
-        ("CommitTotal", ctypes.c_size_t),
-        ("CommitLimit", ctypes.c_size_t),
-        ("CommitPeak", ctypes.c_size_t),
-        ("PhysicalTotal", ctypes.c_size_t),
-        ("PhysicalAvailable", ctypes.c_size_t),
-        ("SystemCache", ctypes.c_size_t),
-        ("KernelTotal", ctypes.c_size_t),
-        ("KernelPaged", ctypes.c_size_t),
-        ("KernelNonpaged", ctypes.c_size_t),
-        ("PageSize", ctypes.c_size_t),
-        ("HandleCount", wintypes.DWORD),
-        ("ProcessCount", wintypes.DWORD),
-        ("ThreadCount", wintypes.DWORD),
-    ]
-
-def get_free_ram():
-    #Windows is way too conservative and chalks recently used uncommitted model RAM
-    #as "in-use". So, calculate free RAM for the sake of general use as the greater of:
-    #
-    #1: What psutil says
-    #2: Total Memory - (Committed Memory - VRAM in use)
-    #
-    #We have to subtract VRAM in use from the comitted memory as WDDM creates a naked
-    #commit charge for all VRAM used just incase it wants to page it all out. This just
-    #isn't realistic so "overcommit" on our calculations by just subtracting it off.
-
-    pi = PERFORMANCE_INFORMATION()
-    pi.cb = ctypes.sizeof(pi)
-
-    if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
-        logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal")
-        return psutil.virtual_memory().available
-
-    committed = pi.CommitTotal * pi.PageSize
-    total = pi.PhysicalTotal * pi.PageSize
-
-    return max(psutil.virtual_memory().available,
-               total - (committed - comfy_aimdo.control.get_total_vram_usage()))
-
diff --git a/requirements.txt b/requirements.txt
index 5c7ff76be..3d0d2e6d8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
 filelock
 av>=14.2.0
 comfy-kitchen>=0.2.8
-comfy-aimdo==0.3.0
+comfy-aimdo==0.4.0
 requests
 simpleeval>=1.0.0
 blake3