mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
Merge 8819577433 into 8b08bfdcbe
This commit is contained in:
commit
5a5e493064
@ -475,16 +475,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
||||
|
||||
return weight
|
||||
|
||||
def prefetch_prepared_value(value, allocate_buffer, stream):
|
||||
def prefetch_prepared_value(value, counter, destination, stream, copy):
|
||||
if isinstance(value, torch.Tensor):
|
||||
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
size = comfy.memory_management.vram_aligned_size(value)
|
||||
offset = counter[0]
|
||||
counter[0] += size
|
||||
if destination is None:
|
||||
return value
|
||||
|
||||
dest = destination[offset:offset + size]
|
||||
if copy:
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
|
||||
elif isinstance(value, weight_adapter.WeightAdapterBase):
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
|
||||
return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
|
||||
elif isinstance(value, list):
|
||||
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
|
||||
return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
@ -48,6 +48,12 @@ def read_tensor_file_slice_into(tensor, destination):
|
||||
if info.size == 0:
|
||||
return True
|
||||
|
||||
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
|
||||
if hostbuf is not None:
|
||||
hostbuf.read_file_slice(file_obj, info.offset, info.size,
|
||||
offset=destination.data_ptr() - hostbuf.get_raw_address())
|
||||
return True
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.vram_buffer
|
||||
|
||||
class VRAMState(Enum):
|
||||
@ -495,6 +496,10 @@ except:
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
DIRTY_MMAPS = set()
|
||||
|
||||
PIN_PRESSURE_HYSTERESIS = 128 * 1024 * 1024
|
||||
|
||||
def module_size(module):
|
||||
module_mem = 0
|
||||
sd = module.state_dict()
|
||||
@ -503,27 +508,26 @@ def module_size(module):
|
||||
module_mem += t.nbytes
|
||||
return module_mem
|
||||
|
||||
def module_mmap_residency(module, free=False):
|
||||
mmap_touched_mem = 0
|
||||
module_mem = 0
|
||||
bounced_mmaps = set()
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nbytes
|
||||
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
|
||||
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
||||
continue
|
||||
mmap_touched_mem += t.nbytes
|
||||
if not free:
|
||||
continue
|
||||
storage._comfy_tensor_mmap_touched = False
|
||||
mmap_obj = storage._comfy_tensor_mmap_refs[0]
|
||||
if mmap_obj in bounced_mmaps:
|
||||
continue
|
||||
mmap_obj.bounce()
|
||||
bounced_mmaps.add(mmap_obj)
|
||||
return mmap_touched_mem, module_mem
|
||||
def mark_mmap_dirty(storage):
|
||||
mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
|
||||
if mmap_refs is not None:
|
||||
DIRTY_MMAPS.add(mmap_refs[0])
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return
|
||||
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
if shortfall <= 0:
|
||||
return
|
||||
|
||||
shortfall += PIN_PRESSURE_HYSTERESIS
|
||||
for loaded_model in reversed(current_loaded_models):
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
|
||||
shortfall -= model.partially_unload_ram(shortfall)
|
||||
if shortfall <= 0:
|
||||
break
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
@ -553,9 +557,6 @@ class LoadedModel:
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return self.model.model_mmap_residency(free=free)
|
||||
|
||||
def model_loaded_memory(self):
|
||||
return self.model.loaded_size()
|
||||
|
||||
@ -635,15 +636,9 @@ WINDOWS = any(platform.win32_ver())
|
||||
|
||||
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
|
||||
if WINDOWS:
|
||||
import comfy.windows
|
||||
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
|
||||
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
|
||||
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
|
||||
def get_free_ram():
|
||||
return comfy.windows.get_free_ram()
|
||||
else:
|
||||
def get_free_ram():
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
@ -657,7 +652,6 @@ def minimum_inference_memory():
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||
cleanup_models_gc()
|
||||
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
unloaded_models = []
|
||||
@ -673,11 +667,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
pins_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY or device is None:
|
||||
if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
|
||||
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
||||
pins_to_free = pins_required - get_free_ram()
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
if for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||
@ -685,18 +677,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
if pins_to_free > 0:
|
||||
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
||||
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
ram_to_free = ram_required - psutil.virtual_memory().available
|
||||
if ram_to_free <= 0 and i not in unloaded_model:
|
||||
continue
|
||||
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
|
||||
if resident_memory > 0:
|
||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
@ -762,29 +742,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
model_to_unload.model.detach(unpatch_all=False)
|
||||
model_to_unload.model_finalizer.detach()
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
total_pins_required = {}
|
||||
total_ram_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
device = loaded_model.device
|
||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||
resident_memory, model_memory = loaded_model.model_mmap_residency()
|
||||
pinned_memory = loaded_model.model.pinned_memory_size()
|
||||
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
|
||||
#make this JIT to keep as much pinned as possible.
|
||||
pins_required = model_memory - pinned_memory
|
||||
ram_required = model_memory - resident_memory
|
||||
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
|
||||
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||
device,
|
||||
for_dynamic=free_for_dynamic,
|
||||
pins_required=total_pins_required[device],
|
||||
ram_required=total_ram_required[device])
|
||||
for_dynamic=free_for_dynamic)
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
@ -1180,6 +1147,7 @@ STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_PIN_BUFFERS = {}
|
||||
|
||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||
|
||||
@ -1220,21 +1188,62 @@ def get_aimdo_cast_buffer(offload_stream, device):
|
||||
if cast_buffer is None:
|
||||
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
|
||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
|
||||
return cast_buffer
|
||||
|
||||
def get_pin_buffer(offload_stream):
|
||||
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
|
||||
if pin_buffer is None:
|
||||
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0)
|
||||
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
|
||||
elif offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
return pin_buffer
|
||||
|
||||
def resize_pin_buffer(pin_buffer, size):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
old_size = pin_buffer.size
|
||||
if size <= old_size:
|
||||
return True
|
||||
growth = size - old_size
|
||||
ensure_pin_budget(growth, evict_active=True)
|
||||
try:
|
||||
pin_buffer.extend(size=size, reallocate=True)
|
||||
except RuntimeError:
|
||||
return False
|
||||
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
|
||||
return True
|
||||
|
||||
def reset_cast_buffers():
|
||||
global TOTAL_PINNED_MEMORY
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
|
||||
if offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
|
||||
for mmap_obj in DIRTY_MMAPS:
|
||||
mmap_obj.bounce()
|
||||
DIRTY_MMAPS.clear()
|
||||
|
||||
for pin_buffer in STREAM_PIN_BUFFERS.values():
|
||||
TOTAL_PINNED_MEMORY -= pin_buffer.size
|
||||
if TOTAL_PINNED_MEMORY < 0:
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
|
||||
for loaded_model in current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic():
|
||||
model.model.dynamic_pins[model.load_device]["active"] = False
|
||||
model.partially_unload_ram(1e30, subsets=[ "patches" ])
|
||||
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), [])
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
STREAM_PIN_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
@ -1296,8 +1305,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||
continue
|
||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
||||
storage._comfy_tensor_mmap_touched = True
|
||||
mark_mmap_dirty(storage)
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
|
||||
|
||||
@ -1378,8 +1386,7 @@ def pin_memory(tensor):
|
||||
return False
|
||||
|
||||
size = tensor.nbytes
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
ensure_pin_budget(size)
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if ptr == 0:
|
||||
|
||||
@ -35,6 +35,7 @@ import comfy.model_management
|
||||
import comfy.ops
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
import comfy_aimdo.host_buffer
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
@ -117,6 +118,8 @@ def string_to_seed(data):
|
||||
return comfy.utils.string_to_seed(data)
|
||||
|
||||
class LowVramPatch:
|
||||
is_lowvram_patch = True
|
||||
|
||||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||
self.key = key
|
||||
self.patches = patches
|
||||
@ -124,11 +127,21 @@ class LowVramPatch:
|
||||
self.set_func = set_func
|
||||
self.prepared_patches = None
|
||||
|
||||
def prepare(self, allocate_buffer, stream):
|
||||
self.prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
|
||||
def memory_required(self):
|
||||
counter = [0]
|
||||
for patch in self.patches[self.key]:
|
||||
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
|
||||
return counter[0]
|
||||
|
||||
def prepare(self, destination, stream, copy=True, commit=True):
|
||||
counter = [0]
|
||||
prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
|
||||
for patch in self.patches[self.key]
|
||||
]
|
||||
if commit:
|
||||
self.prepared_patches = prepared_patches
|
||||
return prepared_patches
|
||||
|
||||
def clear_prepared(self):
|
||||
self.prepared_patches = None
|
||||
@ -310,9 +323,6 @@ class ModelPatcher:
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||
|
||||
def loaded_size(self):
|
||||
return self.model.model_loaded_weight_memory
|
||||
|
||||
@ -1088,7 +1098,7 @@ class ModelPatcher:
|
||||
return 0
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
pass
|
||||
return 0
|
||||
|
||||
def detach(self, unpatch_all=True):
|
||||
self.eject_model()
|
||||
@ -1495,6 +1505,15 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
|
||||
if not hasattr(self.model, "dynamic_vbars"):
|
||||
self.model.dynamic_vbars = {}
|
||||
if not hasattr(self.model, "dynamic_pins"):
|
||||
self.model.dynamic_pins = {}
|
||||
if self.load_device not in self.model.dynamic_pins:
|
||||
self.model.dynamic_pins[self.load_device] = {
|
||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024), []),
|
||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024), []),
|
||||
"failed": False,
|
||||
"active": False,
|
||||
}
|
||||
self.non_dynamic_delegate_model = None
|
||||
assert load_device is not None
|
||||
|
||||
@ -1556,6 +1575,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
self.unpatch_hooks()
|
||||
|
||||
vbar = self._vbar_get(create=True)
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
pin_state["failed"] = False
|
||||
pin_state["active"] = True
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
@ -1581,7 +1603,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if key in self.patches:
|
||||
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
|
||||
return (True, 0)
|
||||
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
||||
lowvram_patch = LowVramPatch(key, self.patches)
|
||||
lowvram_patch._pin_state = pin_state
|
||||
setattr(m, param_key + "_lowvram_function", lowvram_patch)
|
||||
num_patches += 1
|
||||
else:
|
||||
setattr(m, param_key + "_lowvram_function", None)
|
||||
@ -1607,8 +1631,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.comfy_cast_weights = True
|
||||
m.pin_failed = False
|
||||
m.seed_key = n
|
||||
m._pin_state = pin_state
|
||||
set_dirty(m, dirty)
|
||||
|
||||
force_load, v_weight_size = setup_param(self, m, n, "weight")
|
||||
@ -1686,22 +1710,27 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
return freed
|
||||
|
||||
def pinned_memory_size(self):
|
||||
total = 0
|
||||
loading = self._load_list(for_dynamic=True)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
pin = comfy.pinned_memory.get_pin(m)
|
||||
if pin is not None:
|
||||
total += pin.numel() * pin.element_size()
|
||||
return total
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
|
||||
self.model.dynamic_pins[self.load_device]["patches"][0].size)
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
*_, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
if ram_to_unload <= 0:
|
||||
return
|
||||
def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
|
||||
freed = 0
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
for subset in subsets:
|
||||
hostbuf, stack = pin_state[subset]
|
||||
while len(stack) > 0:
|
||||
module, offset = stack.pop()
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
del module._pin
|
||||
hostbuf.truncate(offset)
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||
freed += size
|
||||
ram_to_unload -= size
|
||||
if ram_to_unload <= 0:
|
||||
return freed
|
||||
return freed
|
||||
|
||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||
#This isn't used by the core at all and can only be to load a model out of
|
||||
|
||||
82
comfy/ops.py
82
comfy/ops.py
@ -75,6 +75,8 @@ except:
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
offload_stream = None
|
||||
cast_buffer = None
|
||||
cast_buffer_offset = 0
|
||||
stream_pin_hostbuf = None
|
||||
stream_pin_offset = 0
|
||||
stream_pin_queue = []
|
||||
|
||||
def ensure_offload_stream(module, required_size, check_largest):
|
||||
nonlocal offload_stream
|
||||
@ -124,6 +129,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
cast_buffer_offset += buffer_size
|
||||
return buffer
|
||||
|
||||
def get_stream_pin_buffer_offset(buffer_size):
|
||||
nonlocal stream_pin_hostbuf
|
||||
nonlocal stream_pin_offset
|
||||
|
||||
if buffer_size == 0 or offload_stream is None:
|
||||
return None
|
||||
|
||||
if stream_pin_hostbuf is None:
|
||||
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
|
||||
if stream_pin_hostbuf is None:
|
||||
return None
|
||||
|
||||
offset = stream_pin_offset
|
||||
stream_pin_offset += buffer_size
|
||||
return offset
|
||||
|
||||
for s in comfy_modules:
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
@ -162,23 +183,45 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if xfer_dest is None:
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
if signature is None and pin is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
pin = comfy.pinned_memory.get_pin(s)
|
||||
else:
|
||||
pin = None
|
||||
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
|
||||
if xfer_source is not None:
|
||||
if getattr(xfer_source, "is_lowvram_patch", False):
|
||||
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
||||
else:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
|
||||
|
||||
if pin is not None:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, pin)
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
def handle_pin_miss(m, source, dest, subset="weights", size=None):
|
||||
pin = None
|
||||
if signature is None:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch(source, pin, None)
|
||||
return [ pin ]
|
||||
if pin is None:
|
||||
pin_offset = get_stream_pin_buffer_offset(size)
|
||||
if pin_offset is not None:
|
||||
stream_pin_queue.append((source, pin_offset, size, dest))
|
||||
return None
|
||||
return source
|
||||
|
||||
if pin is None:
|
||||
xfer_source = handle_pin_miss(s, xfer_source, xfer_dest, size=dest_size)
|
||||
|
||||
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
|
||||
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_source = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_source is not None:
|
||||
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
|
||||
lowvram_size = lowvram_source.memory_required()
|
||||
lowvram_dest = get_cast_buffer(lowvram_size)
|
||||
lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
|
||||
|
||||
pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
|
||||
lowvram_source = handle_pin_miss(lowvram_source, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) if pin is None else [ pin ]
|
||||
|
||||
cast_maybe_lowvram_patch(lowvram_source, lowvram_dest, offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
@ -186,6 +229,19 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
prefetch["needs_cast"] = needs_cast
|
||||
s._prefetch = prefetch
|
||||
|
||||
if stream_pin_offset > 0:
|
||||
if stream_pin_hostbuf.size < stream_pin_offset:
|
||||
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
|
||||
for xfer_source, _, _, xfer_dest in stream_pin_queue:
|
||||
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
|
||||
return offload_stream
|
||||
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf, size=stream_pin_offset)
|
||||
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
|
||||
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
|
||||
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
|
||||
cast_maybe_lowvram_patch(xfer_source, pin, None)
|
||||
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
|
||||
return offload_stream
|
||||
|
||||
|
||||
|
||||
@ -5,39 +5,28 @@ import comfy_aimdo.torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
def get_pin(module):
|
||||
def get_pin(module, subset="weights"):
|
||||
return getattr(module, "_pin", None)
|
||||
|
||||
def pin_memory(module):
|
||||
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
|
||||
def pin_memory(module, subset="weights", size=None):
|
||||
pin_state = module._pin_state
|
||||
if pin_state["failed"] or args.disable_pinned_memory or get_pin(module, subset) is not None:
|
||||
return
|
||||
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
|
||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
hostbuf, stack = pin_state[subset]
|
||||
if size is None:
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
offset = hostbuf.size
|
||||
comfy.model_management.ensure_pin_budget(size)
|
||||
|
||||
try:
|
||||
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
|
||||
hostbuf.extend(size=size)
|
||||
except RuntimeError:
|
||||
module.pin_failed = True
|
||||
pin_state["failed"] = True
|
||||
return False
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
|
||||
module._pin_hostbuf = hostbuf
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
stack.append((module, offset))
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
|
||||
def unpin_memory(module):
|
||||
if get_pin(module) is None:
|
||||
return 0
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||
|
||||
del module._pin
|
||||
del module._pin_hostbuf
|
||||
return size
|
||||
|
||||
@ -113,7 +113,6 @@ def load_safetensors(ckpt):
|
||||
"_comfy_tensor_file_slice",
|
||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||
setattr(storage, "_comfy_tensor_mmap_touched", False)
|
||||
sd[name] = tensor
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
@ -1445,4 +1444,3 @@ def deepcopy_list_dict(obj, memo=None):
|
||||
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
|
||||
@ -1,52 +0,0 @@
|
||||
import ctypes
|
||||
import logging
|
||||
import psutil
|
||||
from ctypes import wintypes
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
psapi = ctypes.WinDLL("psapi")
|
||||
kernel32 = ctypes.WinDLL("kernel32")
|
||||
|
||||
class PERFORMANCE_INFORMATION(ctypes.Structure):
|
||||
_fields_ = [
|
||||
("cb", wintypes.DWORD),
|
||||
("CommitTotal", ctypes.c_size_t),
|
||||
("CommitLimit", ctypes.c_size_t),
|
||||
("CommitPeak", ctypes.c_size_t),
|
||||
("PhysicalTotal", ctypes.c_size_t),
|
||||
("PhysicalAvailable", ctypes.c_size_t),
|
||||
("SystemCache", ctypes.c_size_t),
|
||||
("KernelTotal", ctypes.c_size_t),
|
||||
("KernelPaged", ctypes.c_size_t),
|
||||
("KernelNonpaged", ctypes.c_size_t),
|
||||
("PageSize", ctypes.c_size_t),
|
||||
("HandleCount", wintypes.DWORD),
|
||||
("ProcessCount", wintypes.DWORD),
|
||||
("ThreadCount", wintypes.DWORD),
|
||||
]
|
||||
|
||||
def get_free_ram():
|
||||
#Windows is way too conservative and chalks recently used uncommitted model RAM
|
||||
#as "in-use". So, calculate free RAM for the sake of general use as the greater of:
|
||||
#
|
||||
#1: What psutil says
|
||||
#2: Total Memory - (Committed Memory - VRAM in use)
|
||||
#
|
||||
#We have to subtract VRAM in use from the comitted memory as WDDM creates a naked
|
||||
#commit charge for all VRAM used just incase it wants to page it all out. This just
|
||||
#isn't realistic so "overcommit" on our calculations by just subtracting it off.
|
||||
|
||||
pi = PERFORMANCE_INFORMATION()
|
||||
pi.cb = ctypes.sizeof(pi)
|
||||
|
||||
if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
|
||||
logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal")
|
||||
return psutil.virtual_memory().available
|
||||
|
||||
committed = pi.CommitTotal * pi.PageSize
|
||||
total = pi.PhysicalTotal * pi.PageSize
|
||||
|
||||
return max(psutil.virtual_memory().available,
|
||||
total - (committed - comfy_aimdo.control.get_total_vram_usage()))
|
||||
|
||||
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-aimdo==0.3.0
|
||||
comfy-aimdo==0.4.0
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
Loading…
Reference in New Issue
Block a user