commit 5a5e493064 by rattus, 2026-05-08 21:27:31 -06:00 (committed by GitHub)
9 changed files with 230 additions and 190 deletions

View File

@@ -475,16 +475,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
     return weight

-def prefetch_prepared_value(value, allocate_buffer, stream):
+def prefetch_prepared_value(value, counter, destination, stream, copy):
     if isinstance(value, torch.Tensor):
-        dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
-        comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
+        size = comfy.memory_management.vram_aligned_size(value)
+        offset = counter[0]
+        counter[0] += size
+        if destination is None:
+            return value
+        dest = destination[offset:offset + size]
+        if copy:
+            comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
         return comfy.memory_management.interpret_gathered_like([value], dest)[0]
     elif isinstance(value, weight_adapter.WeightAdapterBase):
-        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
+        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
     elif isinstance(value, tuple):
-        return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
+        return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
     elif isinstance(value, list):
-        return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
+        return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
     return value
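The rewritten `prefetch_prepared_value` threads a one-element `counter` list through the recursion as a shared mutable cursor: called with `destination=None` it is a dry run that only measures, and called with a real buffer it hands each tensor a slice at the offset the dry run computed. A minimal standalone sketch of that two-pass pattern (plain torch, with byte size standing in for `vram_aligned_size`; not the repo's actual API):

```python
import torch

def pack(values, counter, destination):
    # counter is a 1-element list so all recursion levels share one cursor
    out = []
    for v in values:
        size = v.numel() * v.element_size()  # stand-in for vram_aligned_size
        offset = counter[0]
        counter[0] += size
        # dry run: measure only, hand the value back untouched
        out.append(v if destination is None else destination[offset:offset + size])
    return out

values = [torch.zeros(4, dtype=torch.uint8), torch.ones(8, dtype=torch.uint8)]
counter = [0]
pack(values, counter, None)                  # measuring pass
staging = torch.empty(counter[0], dtype=torch.uint8)
slices = pack(values, [0], staging)          # packing pass walks identically
```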

View File

@@ -48,6 +48,12 @@ def read_tensor_file_slice_into(tensor, destination):
     if info.size == 0:
         return True

+    hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
+    if hostbuf is not None:
+        hostbuf.read_file_slice(file_obj, info.offset, info.size,
+                                offset=destination.data_ptr() - hostbuf.get_raw_address())
+        return True
+
     buf_type = ctypes.c_ubyte * info.size
     view = memoryview(buf_type.from_address(destination.data_ptr()))
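When the destination tensor is a view of a pinned host buffer (tagged via the `_comfy_hostbuf` attribute set elsewhere in this commit), the file slice can be read straight into the pinned region at `data_ptr() - get_raw_address()`, skipping the ctypes memoryview path. A rough standalone sketch of that offset arithmetic with a stand-in buffer class (the real one lives in `comfy_aimdo.host_buffer`):

```python
import ctypes
import io

class HostBufferStub:
    """Stand-in: one contiguous allocation; the real class pins it for DMA."""
    def __init__(self, size):
        self._raw = ctypes.create_string_buffer(size)

    def get_raw_address(self):
        return ctypes.addressof(self._raw)

    def read_file_slice(self, file_obj, file_offset, size, offset=0):
        # read file bytes [file_offset, file_offset + size) into the buffer at `offset`
        file_obj.seek(file_offset)
        view = memoryview((ctypes.c_ubyte * size).from_address(
            self.get_raw_address() + offset))
        file_obj.readinto(view)

buf = HostBufferStub(16)
f = io.BytesIO(b"0123456789abcdef")
buf.read_file_slice(f, 4, 8, offset=2)  # a tensor view starting 2 bytes into the arena
```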

View File

@@ -31,6 +31,7 @@ from contextlib import nullcontext
 import comfy.memory_management
 import comfy.utils
 import comfy.quant_ops
+import comfy_aimdo.host_buffer
 import comfy_aimdo.vram_buffer

 class VRAMState(Enum):
@@ -495,6 +496,10 @@ except:

 current_loaded_models = []

+DIRTY_MMAPS = set()
+
+PIN_PRESSURE_HYSTERESIS = 128 * 1024 * 1024
+
 def module_size(module):
     module_mem = 0
     sd = module.state_dict()
@@ -503,27 +508,26 @@ def module_size(module):
         module_mem += t.nbytes
     return module_mem

-def module_mmap_residency(module, free=False):
-    mmap_touched_mem = 0
-    module_mem = 0
-    bounced_mmaps = set()
-    sd = module.state_dict()
-    for k in sd:
-        t = sd[k]
-        module_mem += t.nbytes
-        storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
-        if not getattr(storage, "_comfy_tensor_mmap_touched", False):
-            continue
-        mmap_touched_mem += t.nbytes
-        if not free:
-            continue
-        storage._comfy_tensor_mmap_touched = False
-        mmap_obj = storage._comfy_tensor_mmap_refs[0]
-        if mmap_obj in bounced_mmaps:
-            continue
-        mmap_obj.bounce()
-        bounced_mmaps.add(mmap_obj)
-    return mmap_touched_mem, module_mem
+def mark_mmap_dirty(storage):
+    mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
+    if mmap_refs is not None:
+        DIRTY_MMAPS.add(mmap_refs[0])
+
+def ensure_pin_budget(size, evict_active=False):
+    if MAX_PINNED_MEMORY <= 0:
+        return
+
+    shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
+    if shortfall <= 0:
+        return
+
+    shortfall += PIN_PRESSURE_HYSTERESIS
+    for loaded_model in reversed(current_loaded_models):
+        model = loaded_model.model
+        if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
+            shortfall -= model.partially_unload_ram(shortfall)
+            if shortfall <= 0:
+                break

 class LoadedModel:
     def __init__(self, model):
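`ensure_pin_budget` frees past the exact overshoot by `PIN_PRESSURE_HYSTERESIS`, so a burst of small pin requests triggers one batched eviction instead of one eviction per request, walking loaded models most-recent-last and skipping "active" ones unless `evict_active` is set. A toy illustration of why the slack matters (hypothetical numbers, stand-in unpin function):

```python
MAX_PINNED = 1000
HYSTERESIS = 128
total_pinned = 990

def unpin(amount):
    global total_pinned
    total_pinned -= amount

def ensure_budget(size):
    shortfall = total_pinned + size - MAX_PINNED
    if shortfall <= 0:
        return 0
    # free the overshoot plus slack, so the next ~128 bytes of requests
    # land under budget without triggering their own eviction round
    unpin(shortfall + HYSTERESIS)
    return shortfall + HYSTERESIS

assert ensure_budget(20) == 138   # one eviction round
total_pinned += 20                # the pin proceeds
assert ensure_budget(20) == 0     # slack absorbs the next request
```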
@@ -553,9 +557,6 @@ class LoadedModel:
     def model_memory(self):
         return self.model.model_size()

-    def model_mmap_residency(self, free=False):
-        return self.model.model_mmap_residency(free=free)
-
     def model_loaded_memory(self):
         return self.model.loaded_size()
@@ -635,15 +636,9 @@ WINDOWS = any(platform.win32_ver())

 EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
 if WINDOWS:
-    import comfy.windows
     EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
     if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
         EXTRA_RESERVED_VRAM += 100 * 1024 * 1024

-    def get_free_ram():
-        return comfy.windows.get_free_ram()
-else:
-    def get_free_ram():
-        return psutil.virtual_memory().available
-
 if args.reserve_vram is not None:
     EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@@ -657,7 +652,6 @@ def minimum_inference_memory():

 def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
     cleanup_models_gc()
-    comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -673,11 +667,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
     for x in can_unload_sorted:
         i = x[-1]
         memory_to_free = 1e32
-        pins_to_free = 1e32
-        if not DISABLE_SMART_MEMORY or device is None:
+        if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
             memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
-            pins_to_free = pins_required - get_free_ram()
-        if current_loaded_models[i].model.is_dynamic() and for_dynamic:
-            #don't actually unload dynamic models for the sake of other dynamic models
-            #as that works on-demand.
-            memory_required -= current_loaded_models[i].model.loaded_size()
+            if for_dynamic:
+                #don't actually unload dynamic models for the sake of other dynamic models
+                #as that works on-demand.
+                memory_required -= current_loaded_models[i].model.loaded_size()
@@ -685,18 +677,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
         if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
             logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
             unloaded_model.append(i)
-        if pins_to_free > 0:
-            logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-            current_loaded_models[i].model.partially_unload_ram(pins_to_free)
-
-    for x in can_unload_sorted:
-        i = x[-1]
-        ram_to_free = ram_required - psutil.virtual_memory().available
-        if ram_to_free <= 0 and i not in unloaded_model:
-            continue
-        resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
-        if resident_memory > 0:
-            logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")

     for i in sorted(unloaded_model, reverse=True):
         unloaded_models.append(current_loaded_models.pop(i))
@@ -762,29 +742,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
             model_to_unload.model.detach(unpatch_all=False)
             model_to_unload.model_finalizer.detach()

     total_memory_required = {}
-    total_pins_required = {}
-    total_ram_required = {}
     for loaded_model in models_to_load:
         device = loaded_model.device
         total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
-        resident_memory, model_memory = loaded_model.model_mmap_residency()
-        pinned_memory = loaded_model.model.pinned_memory_size()
-        #FIXME: This can over-free the pins as it budgets to pin the entire model. We should
-        #make this JIT to keep as much pinned as possible.
-        pins_required = model_memory - pinned_memory
-        ram_required = model_memory - resident_memory
-        total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
-        total_ram_required[device] = total_ram_required.get(device, 0) + ram_required

     for device in total_memory_required:
         if device != torch.device("cpu"):
             free_memory(total_memory_required[device] * 1.1 + extra_mem,
                         device,
-                        for_dynamic=free_for_dynamic,
-                        pins_required=total_pins_required[device],
-                        ram_required=total_ram_required[device])
+                        for_dynamic=free_for_dynamic)

     for device in total_memory_required:
         if device != torch.device("cpu"):
@@ -1180,6 +1147,7 @@ STREAM_CAST_BUFFERS = {}
 LARGEST_CASTED_WEIGHT = (None, 0)
 STREAM_AIMDO_CAST_BUFFERS = {}
 LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
+STREAM_PIN_BUFFERS = {}

 DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
@@ -1220,21 +1188,62 @@ def get_aimdo_cast_buffer(offload_stream, device):
     if cast_buffer is None:
         cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
         STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
     return cast_buffer

+def get_pin_buffer(offload_stream):
+    pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
+    if pin_buffer is None:
+        pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0)
+        STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
+    elif offload_stream is not None:
+        offload_stream.synchronize()
+    return pin_buffer
+
+def resize_pin_buffer(pin_buffer, size):
+    global TOTAL_PINNED_MEMORY
+    old_size = pin_buffer.size
+    if size <= old_size:
+        return True
+
+    growth = size - old_size
+    ensure_pin_budget(growth, evict_active=True)
+    try:
+        pin_buffer.extend(size=size, reallocate=True)
+    except RuntimeError:
+        return False
+    TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
+    return True
+
 def reset_cast_buffers():
+    global TOTAL_PINNED_MEMORY
     global LARGEST_CASTED_WEIGHT
     global LARGEST_AIMDO_CASTED_WEIGHT
     LARGEST_CASTED_WEIGHT = (None, 0)
     LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
-    for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
+    for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
         if offload_stream is not None:
             offload_stream.synchronize()
     synchronize()
+    for mmap_obj in DIRTY_MMAPS:
+        mmap_obj.bounce()
+    DIRTY_MMAPS.clear()
+    for pin_buffer in STREAM_PIN_BUFFERS.values():
+        TOTAL_PINNED_MEMORY -= pin_buffer.size
+        if TOTAL_PINNED_MEMORY < 0:
+            TOTAL_PINNED_MEMORY = 0
+    for loaded_model in current_loaded_models:
+        model = loaded_model.model
+        if model is not None and model.is_dynamic():
+            model.model.dynamic_pins[model.load_device]["active"] = False
+            model.partially_unload_ram(1e30, subsets=[ "patches" ])
+            model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), [])
     STREAM_CAST_BUFFERS.clear()
     STREAM_AIMDO_CAST_BUFFERS.clear()
+    STREAM_PIN_BUFFERS.clear()
     soft_empty_cache()

 def get_offload_stream(device):
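Instead of flagging every touched storage and rescanning whole modules later (the deleted `module_mmap_residency` walk), the commit records each dirtied mapping once in the `DIRTY_MMAPS` set and bounces them all in a single pass inside `reset_cast_buffers`. A sketch of that bookkeeping shape (stub `bounce`, standing in for whatever page release the real mmap object performs):

```python
DIRTY = set()

class MmapStub:
    def __init__(self, name):
        self.name = name

    def bounce(self):
        # the real object presumably drops its resident pages back to the OS
        print(f"released pages of {self.name}")

def mark_dirty(mmap_obj):
    DIRTY.add(mmap_obj)  # set dedup: one bounce per mapping, however many tensors touched it

def reset():
    for m in DIRTY:
        m.bounce()
    DIRTY.clear()

shared = MmapStub("model.safetensors")
mark_dirty(shared)
mark_dirty(shared)   # many tensors, same backing file
reset()              # prints once
```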
@@ -1296,8 +1305,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
         if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
             continue
         storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
-        if hasattr(storage, "_comfy_tensor_mmap_touched"):
-            storage._comfy_tensor_mmap_touched = True
+        mark_mmap_dirty(storage)

         dest_view.copy_(tensor, non_blocking=non_blocking)
@@ -1378,8 +1386,7 @@ def pin_memory(tensor):
         return False

     size = tensor.nbytes
-    if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
-        return False
+    ensure_pin_budget(size)

     ptr = tensor.data_ptr()
     if ptr == 0:

View File

@@ -35,6 +35,7 @@ import comfy.model_management
 import comfy.ops
 import comfy.patcher_extension
 import comfy.utils
+import comfy_aimdo.host_buffer
 from comfy.comfy_types import UnetWrapperFunction
 from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@@ -117,6 +118,8 @@ def string_to_seed(data):
     return comfy.utils.string_to_seed(data)

 class LowVramPatch:
+    is_lowvram_patch = True
+
     def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
@@ -124,11 +127,21 @@ class LowVramPatch:
         self.set_func = set_func
         self.prepared_patches = None

-    def prepare(self, allocate_buffer, stream):
-        self.prepared_patches = [
-            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
+    def memory_required(self):
+        counter = [0]
+        for patch in self.patches[self.key]:
+            comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
+        return counter[0]
+
+    def prepare(self, destination, stream, copy=True, commit=True):
+        counter = [0]
+        prepared_patches = [
+            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
             for patch in self.patches[self.key]
         ]
+        if commit:
+            self.prepared_patches = prepared_patches
+        return prepared_patches

     def clear_prepared(self):
         self.prepared_patches = None
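`memory_required` reuses the exact `prefetch_prepared_value` walk with `destination=None`, so the measuring pass and the packing pass cannot disagree about sizes; the `commit` flag then decides whether a `prepare` call publishes its result or is a throwaway staging pass (the ops code below calls it with `commit=False` when re-staging into a pinned buffer). A compact sketch of the flag's contract (stand-in class, not the real `LowVramPatch`):

```python
class TwoPassPatch:
    def __init__(self):
        self.prepared_patches = None

    def prepare(self, commit=True):
        result = ["staged values"]
        if commit:
            self.prepared_patches = result  # published for the later weight apply
        return result                       # otherwise used once and discarded

p = TwoPassPatch()
scratch = p.prepare(commit=False)
assert p.prepared_patches is None    # the staging pass left no state behind
p.prepare(commit=True)
assert p.prepared_patches is not None
```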
@@ -310,9 +323,6 @@ class ModelPatcher:
             self.size = comfy.model_management.module_size(self.model)
         return self.size

-    def model_mmap_residency(self, free=False):
-        return comfy.model_management.module_mmap_residency(self.model, free=free)
-
     def loaded_size(self):
         return self.model.model_loaded_weight_memory
@@ -1088,7 +1098,7 @@ class ModelPatcher:
         return 0

     def partially_unload_ram(self, ram_to_unload):
-        pass
+        return 0

     def detach(self, unpatch_all=True):
         self.eject_model()
@@ -1495,6 +1505,15 @@ class ModelPatcherDynamic(ModelPatcher):
         super().__init__(model, load_device, offload_device, size, weight_inplace_update)
         if not hasattr(self.model, "dynamic_vbars"):
             self.model.dynamic_vbars = {}
+        if not hasattr(self.model, "dynamic_pins"):
+            self.model.dynamic_pins = {}
+        if self.load_device not in self.model.dynamic_pins:
+            self.model.dynamic_pins[self.load_device] = {
+                "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024), []),
+                "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), []),
+                "failed": False,
+                "active": False,
+            }
         self.non_dynamic_delegate_model = None

         assert load_device is not None
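Each load device gets one `dynamic_pins` entry holding two growable pinned arenas (a larger one for weights, a smaller one for LoRA patches) plus two flags; the list paired with each arena is a LIFO of `(module, offset)` records so the arena can later be shrunk from the tail. Roughly this shape (stub arena; the 64 MiB / 8 MiB second argument is presumably the real `HostBuffer`'s growth chunk):

```python
class ArenaStub:
    """Stand-in for comfy_aimdo.host_buffer.HostBuffer(initial_size, chunk)."""
    def __init__(self, size=0, chunk=0):
        self.size = size
        self.chunk = chunk

pin_state = {
    "weights": (ArenaStub(0, 64 * 1024 * 1024), []),  # arena + LIFO of (module, offset)
    "patches": (ArenaStub(0, 8 * 1024 * 1024), []),
    "failed": False,  # a failed allocation stops retries until the next load
    "active": False,  # True while loading: shields these pins from budget eviction
}
```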
@@ -1556,6 +1575,9 @@ class ModelPatcherDynamic(ModelPatcher):
         self.unpatch_hooks()

         vbar = self._vbar_get(create=True)
+        pin_state = self.model.dynamic_pins[self.load_device]
+        pin_state["failed"] = False
+        pin_state["active"] = True
         if vbar is not None:
             vbar.prioritize()
@@ -1581,7 +1603,9 @@ class ModelPatcherDynamic(ModelPatcher):
                 if key in self.patches:
                     if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
                         return (True, 0)
-                    setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
+                    lowvram_patch = LowVramPatch(key, self.patches)
+                    lowvram_patch._pin_state = pin_state
+                    setattr(m, param_key + "_lowvram_function", lowvram_patch)
                     num_patches += 1
                 else:
                     setattr(m, param_key + "_lowvram_function", None)
@@ -1607,8 +1631,8 @@ class ModelPatcherDynamic(ModelPatcher):
             if hasattr(m, "comfy_cast_weights"):
                 m.comfy_cast_weights = True

-            m.pin_failed = False
             m.seed_key = n
+            m._pin_state = pin_state
             set_dirty(m, dirty)

             force_load, v_weight_size = setup_param(self, m, n, "weight")
@@ -1686,22 +1710,27 @@ class ModelPatcherDynamic(ModelPatcher):
         return freed

     def pinned_memory_size(self):
-        total = 0
-        loading = self._load_list(for_dynamic=True)
-        for x in loading:
-            _, _, _, _, m, _ = x
-            pin = comfy.pinned_memory.get_pin(m)
-            if pin is not None:
-                total += pin.numel() * pin.element_size()
-        return total
+        return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
+                self.model.dynamic_pins[self.load_device]["patches"][0].size)

-    def partially_unload_ram(self, ram_to_unload):
-        loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
-        for x in loading:
-            *_, m, _ = x
-            ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
-            if ram_to_unload <= 0:
-                return
+    def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
+        freed = 0
+        pin_state = self.model.dynamic_pins[self.load_device]
+        for subset in subsets:
+            hostbuf, stack = pin_state[subset]
+            while len(stack) > 0:
+                module, offset = stack.pop()
+                size = module._pin.numel() * module._pin.element_size()
+                del module._pin
+                hostbuf.truncate(offset)
+                comfy.model_management.TOTAL_PINNED_MEMORY -= size
+                if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
+                    comfy.model_management.TOTAL_PINNED_MEMORY = 0
+                freed += size
+                ram_to_unload -= size
+                if ram_to_unload <= 0:
+                    return freed
+        return freed

     def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
         #This isn't used by the core at all and can only be to load a model out of
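Because every pin is appended at the arena's current end, popping the stack and truncating to the popped offset always releases a contiguous tail, so partial unpinning never fragments the buffer. A toy version of that pop-and-truncate loop (stub arena with self-defined `extend`/`truncate` semantics):

```python
class ArenaStub:
    def __init__(self):
        self.size = 0

    def extend(self, new_size):
        self.size = new_size

    def truncate(self, offset):
        self.size = offset  # the tail past `offset` is handed back

arena, stack = ArenaStub(), []

def pin(name, nbytes):
    stack.append((name, arena.size))
    arena.extend(arena.size + nbytes)

def unpin_some(target):
    freed = 0
    while stack and freed < target:
        name, offset = stack.pop()   # newest pin first...
        freed += arena.size - offset
        arena.truncate(offset)       # ...so only the tail is ever released
    return freed

pin("layer1", 100)
pin("layer2", 50)
assert unpin_some(40) == 50 and arena.size == 100
```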

View File

@@ -75,6 +75,8 @@ except:

 cast_to = comfy.model_management.cast_to #TODO: remove once no more references

+STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
+
 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
     offload_stream = None
     cast_buffer = None
     cast_buffer_offset = 0
+    stream_pin_hostbuf = None
+    stream_pin_offset = 0
+    stream_pin_queue = []

     def ensure_offload_stream(module, required_size, check_largest):
         nonlocal offload_stream
@@ -124,6 +129,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
         cast_buffer_offset += buffer_size
         return buffer

+    def get_stream_pin_buffer_offset(buffer_size):
+        nonlocal stream_pin_hostbuf
+        nonlocal stream_pin_offset
+        if buffer_size == 0 or offload_stream is None:
+            return None
+        if stream_pin_hostbuf is None:
+            stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
+        if stream_pin_hostbuf is None:
+            return None
+        offset = stream_pin_offset
+        stream_pin_offset += buffer_size
+        return offset
+
     for s in comfy_modules:
         signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
         resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@@ -162,23 +183,45 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
         if xfer_dest is None:
             xfer_dest = get_cast_buffer(dest_size)

-        if signature is None and pin is None:
-            comfy.pinned_memory.pin_memory(s)
-            pin = comfy.pinned_memory.get_pin(s)
-        else:
-            pin = None
-
-        if pin is not None:
-            comfy.model_management.cast_to_gathered(xfer_source, pin)
-            xfer_source = [ pin ]
-
-        #send it over
-        comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+        def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
+            if xfer_source is not None:
+                if getattr(xfer_source, "is_lowvram_patch", False):
+                    xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
+                else:
+                    comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
+
+        def handle_pin_miss(m, source, dest, subset="weights", size=None):
+            pin = None
+            if signature is None:
+                comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
+                pin = comfy.pinned_memory.get_pin(m, subset=subset)
+            if pin is not None:
+                cast_maybe_lowvram_patch(source, pin, None)
+                return [ pin ]
+            if pin is None:
+                pin_offset = get_stream_pin_buffer_offset(size)
+                if pin_offset is not None:
+                    stream_pin_queue.append((source, pin_offset, size, dest))
+                    return None
+            return source
+
+        if pin is None:
+            xfer_source = handle_pin_miss(s, xfer_source, xfer_dest, size=dest_size)
+        cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)

         for param_key in ("weight", "bias"):
-            lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
-            if lowvram_fn is not None:
+            lowvram_source = getattr(s, param_key + "_lowvram_function", None)
+            if lowvram_source is not None:
                 ensure_offload_stream(s, cast_buffer_offset, False)
-                lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
+                lowvram_size = lowvram_source.memory_required()
+                lowvram_dest = get_cast_buffer(lowvram_size)
+                lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
+                pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
+                lowvram_source = handle_pin_miss(lowvram_source, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) if pin is None else [ pin ]
+                cast_maybe_lowvram_patch(lowvram_source, lowvram_dest, offload_stream)

         prefetch["xfer_dest"] = xfer_dest
         prefetch["cast_dest"] = cast_dest
@@ -186,6 +229,19 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
         prefetch["needs_cast"] = needs_cast
         s._prefetch = prefetch

+    if stream_pin_offset > 0:
+        if stream_pin_hostbuf.size < stream_pin_offset:
+            if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
+                for xfer_source, _, _, xfer_dest in stream_pin_queue:
+                    cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
+                return offload_stream
+        stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
+        stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
+        for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
+            pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
+            cast_maybe_lowvram_patch(xfer_source, pin, None)
+            comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
+
     return offload_stream
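This flush block is the payoff of the queueing in `handle_pin_miss`: every module whose weights could not be pinned individually is deferred, the per-stream host buffer is grown once to cover all of them (plus `STREAM_PIN_BUFFER_HEADROOM`), and each deferred copy is then staged through its pre-assigned slice; if the resize fails, everything falls back to direct unstaged copies. A standalone sketch of that batch-then-flush flow (stand-in names and print statements throughout):

```python
queue = []   # (source, offset, size) records, like stream_pin_queue
cursor = 0

def defer(source, size):
    global cursor
    queue.append((source, cursor, size))
    cursor += size

def flush(grow_buffer_to):
    if cursor == 0:
        return
    if not grow_buffer_to(cursor):
        for source, _, _ in queue:
            print(f"{source}: direct copy (no staging)")  # fallback path
        return
    for source, offset, size in queue:
        print(f"{source}: stage at [{offset}:{offset + size}], then async copy")

defer("weight_a", 4)
defer("weight_b", 8)
flush(lambda needed: True)   # one resize covers both deferred transfers
```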

View File

@@ -5,39 +5,28 @@ import comfy_aimdo.torch

 from comfy.cli_args import args

-def get_pin(module):
+def get_pin(module, subset="weights"):
     return getattr(module, "_pin", None)

-def pin_memory(module):
-    if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
+def pin_memory(module, subset="weights", size=None):
+    pin_state = module._pin_state
+    if pin_state["failed"] or args.disable_pinned_memory or get_pin(module, subset) is not None:
         return

-    size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
-
-    if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
-        module.pin_failed = True
-        return False
+    hostbuf, stack = pin_state[subset]
+    if size is None:
+        size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
+    offset = hostbuf.size
+    comfy.model_management.ensure_pin_budget(size)
     try:
-        hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
+        hostbuf.extend(size=size)
     except RuntimeError:
-        module.pin_failed = True
+        pin_state["failed"] = True
         return False

-    module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
-    module._pin_hostbuf = hostbuf
+    module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
+    module._pin.untyped_storage()._comfy_hostbuf = hostbuf
+    stack.append((module, offset))
     comfy.model_management.TOTAL_PINNED_MEMORY += size
     return True
-
-def unpin_memory(module):
-    if get_pin(module) is None:
-        return 0
-
-    size = module._pin.numel() * module._pin.element_size()
-    comfy.model_management.TOTAL_PINNED_MEMORY -= size
-    if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
-        comfy.model_management.TOTAL_PINNED_MEMORY = 0
-    del module._pin
-    del module._pin_hostbuf
-    return size
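Pins are no longer one `HostBuffer` per module: `pin_memory` suballocates a slice from the shared per-model arena and keeps the arena reachable from the pin's storage through the ad-hoc `_comfy_hostbuf` attribute, which is what the `read_tensor_file_slice_into` fast path above looks for. A sketch of that backref trick using plain torch memory (the real pin is a page-locked view from `comfy_aimdo`; setting attributes on storages is something the loader code in this commit already relies on):

```python
import torch

arena_tensor = torch.empty(1024, dtype=torch.uint8)  # stand-in for the pinned arena

class ArenaStub:
    size = 1024
    def get_raw_address(self):
        return arena_tensor.data_ptr()

pin = arena_tensor[256:512]           # suballocated slice, no new allocation
storage = pin.untyped_storage()
storage._comfy_hostbuf = ArenaStub()  # ad-hoc attribute, as the diff does with setattr
assert getattr(storage, "_comfy_hostbuf", None) is not None
```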

View File

@@ -113,7 +113,6 @@ def load_safetensors(ckpt):
                 "_comfy_tensor_file_slice",
                 comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
             setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
-            setattr(storage, "_comfy_tensor_mmap_touched", False)
             sd[name] = tensor

     return sd, header.get("__metadata__", {}),
@@ -1445,4 +1444,3 @@ def deepcopy_list_dict(obj, memo=None):
     memo[obj_id] = res
     return res
-

View File

@@ -1,52 +0,0 @@
-import ctypes
-import logging
-import psutil
-
-from ctypes import wintypes
-
-import comfy_aimdo.control
-
-psapi = ctypes.WinDLL("psapi")
-kernel32 = ctypes.WinDLL("kernel32")
-
-class PERFORMANCE_INFORMATION(ctypes.Structure):
-    _fields_ = [
-        ("cb", wintypes.DWORD),
-        ("CommitTotal", ctypes.c_size_t),
-        ("CommitLimit", ctypes.c_size_t),
-        ("CommitPeak", ctypes.c_size_t),
-        ("PhysicalTotal", ctypes.c_size_t),
-        ("PhysicalAvailable", ctypes.c_size_t),
-        ("SystemCache", ctypes.c_size_t),
-        ("KernelTotal", ctypes.c_size_t),
-        ("KernelPaged", ctypes.c_size_t),
-        ("KernelNonpaged", ctypes.c_size_t),
-        ("PageSize", ctypes.c_size_t),
-        ("HandleCount", wintypes.DWORD),
-        ("ProcessCount", wintypes.DWORD),
-        ("ThreadCount", wintypes.DWORD),
-    ]
-
-def get_free_ram():
-    #Windows is way too conservative and chalks recently used uncommitted model RAM
-    #as "in-use". So, calculate free RAM for the sake of general use as the greater of:
-    #
-    #1: What psutil says
-    #2: Total Memory - (Committed Memory - VRAM in use)
-    #
-    #We have to subtract VRAM in use from the comitted memory as WDDM creates a naked
-    #commit charge for all VRAM used just incase it wants to page it all out. This just
-    #isn't realistic so "overcommit" on our calculations by just subtracting it off.
-    pi = PERFORMANCE_INFORMATION()
-    pi.cb = ctypes.sizeof(pi)
-    if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
-        logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal")
-        return psutil.virtual_memory().available
-
-    committed = pi.CommitTotal * pi.PageSize
-    total = pi.PhysicalTotal * pi.PageSize
-    return max(psutil.virtual_memory().available,
-               total - (committed - comfy_aimdo.control.get_total_vram_usage()))

View File

@@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
 filelock
 av>=14.2.0
 comfy-kitchen>=0.2.8
-comfy-aimdo==0.3.0
+comfy-aimdo==0.4.0
 requests
 simpleeval>=1.0.0
 blake3