mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 19:42:34 +08:00
Fix disk-weight tiering and meta handling
This commit is contained in:
parent 95ca11fe25
commit c3eaea0429
@@ -179,7 +179,24 @@ def configure(*, allow_gds: bool, pin_if_cpu: bool, ram_headroom_bytes: int, ena
     PIN_IF_CPU = pin_if_cpu
     DISK_WEIGHTS_ENABLED = enabled
     RAM_HEADROOM_BYTES = max(0, int(ram_headroom_bytes))
-    CACHE.set_limit(0 if enabled else 0)
+    if enabled:
+        from . import model_management
+        cpu_capacity_bytes = max(0, model_management.get_total_memory(torch.device("cpu")) - RAM_HEADROOM_BYTES)
+        CACHE.set_limit(cpu_capacity_bytes)
+        LOGGER.debug(
+            "Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
+            enabled,
+            RAM_HEADROOM_BYTES,
+            cpu_capacity_bytes,
+        )
+    else:
+        CACHE.set_limit(0)
+        LOGGER.debug(
+            "Disk weights configured: enabled=%s ram_headroom_bytes=%d cpu_capacity_bytes=%d",
+            enabled,
+            RAM_HEADROOM_BYTES,
+            0,
+        )
     if enabled:
         install_monkeypatches()
     else:
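
The rewritten configure() ties the RAM weight-cache budget to total system memory minus the requested headroom, and zeroes it when disk weights are disabled. A standalone sketch of that arithmetic (the names below are illustrative, not the ComfyUI API):

# Sketch only: how the cache limit above is derived. Assumes the headroom is
# subtracted from total RAM and the result is clamped at zero.
def cpu_cache_limit(total_ram_bytes: int, ram_headroom_bytes: int, enabled: bool) -> int:
    if not enabled:
        return 0  # cache disabled: nothing may be kept resident
    return max(0, total_ram_bytes - max(0, int(ram_headroom_bytes)))

# 64 GiB of RAM with an 8 GiB headroom leaves a 56 GiB cache budget.
assert cpu_cache_limit(64 * 2**30, 8 * 2**30, True) == 56 * 2**30
assert cpu_cache_limit(64 * 2**30, 8 * 2**30, False) == 0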
@@ -449,12 +466,14 @@ def _ensure_free_memory(device: torch.device, required_bytes: int, headroom_byte
     )
     safetensors_stream._reap_pinned_inflight()
     from . import model_management
+    model_management._reap_pinned_inflight()
     model_management.free_memory(required_bytes + headroom_bytes, device)
     free_after = _device_free_memory(device)
     freed = max(0, free_after - free_before)
     LOGGER.debug(
-        "Disk weight memory freed: freed=%d free=%d device=%s",
+        "Disk weight memory freed: freed=%d free_before=%d free_after=%d device=%s",
         freed,
+        free_before,
         free_after,
         device,
     )
@@ -840,6 +859,8 @@ def evict_ram_cache(bytes_to_free: int):
     if bytes_to_free <= 0:
         return 0
     safetensors_stream._reap_pinned_inflight()
+    from . import model_management
+    model_management._reap_pinned_inflight()
    return CACHE.evict_bytes(bytes_to_free)


@@ -926,22 +947,41 @@ def module_to(
     if target_device is None:
         target_device = _find_existing_device(module) or torch.device("cpu")
     if target_device.type == "meta":
-        offload_module_weights(module)
+        for submodule in module.modules():
+            offload_module_weights(submodule)
         return module
-    if allow_materialize:
-        materialize_module_tree(module, target_device)
-        base_kwargs = dict(kwargs)
-        if device is not None and arg_device is None:
-            base_kwargs["device"] = device
-        if dtype is not None and arg_dtype is None:
-            base_kwargs["dtype"] = dtype
-        if non_blocking:
-            base_kwargs["non_blocking"] = non_blocking
-        if memory_format is not None:
-            base_kwargs["memory_format"] = memory_format
-        return BASE_MODULE_TO(module, *args, **base_kwargs)
     dtype_override = dtype or arg_dtype
-    return move_module_tensors(module, target_device, dtype_override=dtype_override)
+    to_kwargs = {}
+    if non_blocking:
+        to_kwargs["non_blocking"] = non_blocking
+    if memory_format is not None:
+        to_kwargs["memory_format"] = memory_format
+    for submodule in module.modules():
+        ensure_module_materialized(submodule, target_device, dtype_override=dtype_override)
+        refs = REGISTRY.get(submodule) or {}
+        for name, param in submodule.named_parameters(recurse=False):
+            if name in refs:
+                continue
+            if param is None or param.device.type == "meta":
+                continue
+            if param.device != target_device or (dtype_override is not None and param.dtype != dtype_override):
+                if dtype_override is not None:
+                    tensor = param.to(device=target_device, dtype=dtype_override, **to_kwargs)
+                else:
+                    tensor = param.to(device=target_device, **to_kwargs)
+                submodule._parameters[name] = torch.nn.Parameter(tensor, requires_grad=param.requires_grad)
+        for name, buf in submodule.named_buffers(recurse=False):
+            if name in refs:
+                continue
+            if buf is None or buf.device.type == "meta":
+                continue
+            if buf.device != target_device or (dtype_override is not None and buf.dtype != dtype_override):
+                if dtype_override is not None:
+                    tensor = buf.to(device=target_device, dtype=dtype_override, **to_kwargs)
+                else:
+                    tensor = buf.to(device=target_device, **to_kwargs)
+                submodule._buffers[name] = tensor
+    return module
     base_kwargs = dict(kwargs)
     if device is not None and arg_device is None:
         base_kwargs["device"] = device
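
For reference, the per-tensor move that the new module_to() branch performs can be sketched with plain PyTorch: walk each submodule, skip tensors still on the meta device, and rebuild parameters and buffers on the target device. REGISTRY and ensure_module_materialized are ComfyUI-side details and are omitted from this sketch.

import torch

def move_tensors(module: torch.nn.Module, device: torch.device, dtype=None):
    # Simplified stand-in for the loop above; not the ComfyUI implementation.
    for sub in module.modules():
        for name, p in list(sub.named_parameters(recurse=False)):
            if p is None or p.device.type == "meta":
                continue  # still offloaded to disk; leave it alone
            t = p.to(device=device, dtype=dtype) if dtype is not None else p.to(device=device)
            sub._parameters[name] = torch.nn.Parameter(t, requires_grad=p.requires_grad)
        for name, b in list(sub.named_buffers(recurse=False)):
            if b is None or b.device.type == "meta":
                continue
            sub._buffers[name] = b.to(device=device, dtype=dtype) if dtype is not None else b.to(device=device)
    return module

m = move_tensors(torch.nn.Linear(4, 4), torch.device("cpu"), dtype=torch.float16)
print(next(m.parameters()).dtype)  # torch.float16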
@@ -18,6 +18,7 @@
 
 import psutil
 import logging
+import collections
 from enum import Enum
 from comfy.cli_args import args, PerformanceFeature
 import torch
@@ -524,18 +525,14 @@ class LoadedModel:
             return True
         return False

-    def model_unload(self, memory_to_free=None, unpatch_weights=True):
+    def model_unload(self, memory_to_free=None, unpatch_weights=True, offload_device=None):
+        target_offload_device = self.model.offload_device if offload_device is None else offload_device
         if memory_to_free is not None:
             if memory_to_free < self.model.loaded_size():
-                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                freed = self.model.partially_unload(target_offload_device, memory_to_free)
                 if freed >= memory_to_free:
                     return False
-        offload_device = None
-        if comfy.disk_weights.disk_weights_enabled():
-            offload_device = torch.device("meta")
-        self.model.detach(unpatch_weights, offload_device=offload_device)
-        if offload_device is not None and offload_device.type == "meta":
-            logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
+        self.model.detach(unpatch_weights, offload_device=target_offload_device)
         self.model_finalizer.detach()
         self.model_finalizer = None
         self.real_model = None
@@ -589,27 +586,30 @@ def minimum_inference_memory():

 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
-    if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
+    if comfy.disk_weights.disk_weights_enabled():
         free_before = get_free_memory(device)
-        headroom = comfy.disk_weights.ram_headroom_bytes()
-        if free_before < memory_required:
-            logging.debug(
-                "RAM pressure: required=%d free=%d headroom=%d",
-                memory_required,
-                free_before,
-                headroom,
-            )
-            freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
-            freed_disk = 0
-            if freed_cache < memory_required:
-                freed_disk = evict_ram_to_disk(memory_required - freed_cache)
-            free_after = get_free_memory(device)
-            freed_total = max(0, free_after - free_before)
-            logging.debug(
-                "RAM freed: freed=%d free=%d",
-                freed_total if freed_total > 0 else freed_cache + freed_disk,
-                free_after,
-            )
+        if is_device_cpu(device):
+            headroom = comfy.disk_weights.ram_headroom_bytes()
+            if free_before < memory_required:
+                logging.debug(
+                    "Disk weights RAM pressure: required=%d free=%d headroom=%d device=%s",
+                    memory_required,
+                    free_before,
+                    headroom,
+                    device,
+                )
+                freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
+                if freed_cache < memory_required:
+                    evict_ram_to_disk(memory_required - freed_cache, keep_loaded=keep_loaded)
+        elif is_device_cuda(device) or is_device_xpu(device):
+            if free_before < memory_required:
+                logging.debug(
+                    "Disk weights VRAM pressure: required=%d free=%d device=%s",
+                    memory_required,
+                    free_before,
+                    device,
+                )
+        _reap_pinned_inflight()
     unloaded_model = []
     can_unload = []
     unloaded_models = []
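
The order of operations in the CPU branch matters: the RAM weight cache is evicted first, and only the shortfall is pushed out to disk. A pure-logic sketch with hypothetical callbacks standing in for evict_ram_cache and evict_ram_to_disk:

def relieve_ram_pressure(required, free, evict_cache, spill_to_disk):
    # Sketch of the tiering order only; returns the number of bytes reclaimed.
    if free >= required:
        return 0
    freed = evict_cache(required)
    if freed < required:
        freed += spill_to_disk(required - freed)
    return freed

# 1 GiB needed, the cache can only give back 256 MiB, the rest spills to disk.
freed = relieve_ram_pressure(
    required=2**30,
    free=0,
    evict_cache=lambda n: min(n, 256 * 2**20),
    spill_to_disk=lambda n: n,
)
print(freed)  # 1073741824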
@@ -630,7 +630,10 @@ def free_memory(memory_required, device, keep_loaded=[]):
                 break
             memory_to_free = memory_required - free_mem
         logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
-        if current_loaded_models[i].model_unload(memory_to_free):
+        offload_device = None
+        if comfy.disk_weights.disk_weights_enabled() and is_device_cpu(device):
+            offload_device = torch.device("meta")
+        if current_loaded_models[i].model_unload(memory_to_free, offload_device=offload_device):
             unloaded_model.append(i)

     for i in sorted(unloaded_model, reverse=True):
@@ -643,6 +646,16 @@ def free_memory(memory_required, device, keep_loaded=[]):
             mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
             if mem_free_torch > mem_free_total * 0.25:
                 soft_empty_cache()
+    if comfy.disk_weights.disk_weights_enabled():
+        free_after = get_free_memory(device)
+        freed_total = max(0, free_after - free_before)
+        logging.debug(
+            "Disk weights free_memory: device=%s free_before=%d free_after=%d freed=%d",
+            device,
+            free_before,
+            free_after,
+            freed_total,
+        )
     return unloaded_models


@@ -826,8 +839,6 @@ def dtype_size(dtype):
     return dtype_size

 def unet_offload_device():
-    if comfy.disk_weights.disk_weights_enabled():
-        return torch.device("meta")
     if vram_state == VRAMState.HIGH_VRAM:
         return get_torch_device()
     else:
@@ -932,8 +943,6 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     return torch.float32

 def text_encoder_offload_device():
-    if comfy.disk_weights.disk_weights_enabled():
-        return torch.device("meta")
     if args.gpu_only:
         return get_torch_device()
     else:
@@ -994,8 +1003,6 @@ def vae_device():
     return get_torch_device()

 def vae_offload_device():
-    if comfy.disk_weights.disk_weights_enabled():
-        return torch.device("meta")
     if args.gpu_only:
         return get_torch_device()
     else:
@@ -1175,6 +1182,15 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
     else:
         r = torch.empty_like(weight, dtype=dtype, device=device)
         r.copy_(weight, non_blocking=non_blocking)
+    if (
+        non_blocking
+        and is_device_cuda(device)
+        and is_device_cpu(weight.device)
+        and weight.is_pinned()
+    ):
+        record_stream = stream if stream is not None else current_stream(device)
+        if record_stream is not None:
+            _track_pinned_inflight(record_stream, weight)
     return r

 def cast_to_device(tensor, device, dtype, copy=False):
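
The new guard records a CUDA event after a non_blocking copy out of pinned host memory so the source buffer cannot be unpinned or reused before the transfer finishes. A self-contained sketch of the same idea (not the ComfyUI helpers):

import collections
import torch

inflight = collections.deque()  # (event, host_tensor) pairs still in flight

def copy_tracked(host_tensor, device):
    stream = torch.cuda.current_stream(device)
    dst = torch.empty_like(host_tensor, device=device)
    dst.copy_(host_tensor, non_blocking=True)
    event = torch.cuda.Event()
    event.record(stream)
    inflight.append((event, host_tensor))  # hold a reference until the event fires
    return dst

def reap():
    while inflight and inflight[0][0].query():  # transfer completed -> safe to drop
        inflight.popleft()

if torch.cuda.is_available():
    src = torch.randn(1024, pin_memory=True)
    dst = copy_tracked(src, torch.device("cuda"))
    reap()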
@@ -1185,6 +1201,8 @@ def cast_to_device(tensor, device, dtype, copy=False):
 PINNED_MEMORY = {}
 TOTAL_PINNED_MEMORY = 0
 MAX_PINNED_MEMORY = -1
+PINNED_INFLIGHT = collections.deque()
+DEFERRED_UNPIN = collections.deque()
 if not args.disable_pinned_memory:
     if is_nvidia() or is_amd():
         if WINDOWS:
@@ -1252,7 +1270,19 @@ def pin_memory(tensor):

     return False

-def unpin_memory(tensor):
+def _track_pinned_inflight(stream, tensor):
+    event = torch.cuda.Event()
+    event.record(stream)
+    PINNED_INFLIGHT.append((event, tensor))
+
+def _tensor_inflight(tensor):
+    ptr = tensor.data_ptr()
+    for _, inflight_tensor in PINNED_INFLIGHT:
+        if inflight_tensor.data_ptr() == ptr:
+            return True
+    return False
+
+def _unpin_memory_now(tensor):
     global TOTAL_PINNED_MEMORY
     if MAX_PINNED_MEMORY <= 0:
         return False
@@ -1283,6 +1313,46 @@ def unpin_memory(tensor):

     return False

+def _retry_deferred_unpins():
+    if not DEFERRED_UNPIN:
+        return
+    remaining = collections.deque()
+    while DEFERRED_UNPIN:
+        tensor = DEFERRED_UNPIN.popleft()
+        if _tensor_inflight(tensor):
+            remaining.append(tensor)
+            continue
+        if not _unpin_memory_now(tensor):
+            remaining.append(tensor)
+    DEFERRED_UNPIN.extend(remaining)
+
+def _reap_pinned_inflight():
+    if not PINNED_INFLIGHT:
+        _retry_deferred_unpins()
+        return
+    remaining = collections.deque()
+    while PINNED_INFLIGHT:
+        event, tensor = PINNED_INFLIGHT.popleft()
+        if event.query():
+            continue
+        remaining.append((event, tensor))
+    PINNED_INFLIGHT.extend(remaining)
+    _retry_deferred_unpins()
+
+def unpin_memory(tensor):
+    global TOTAL_PINNED_MEMORY
+    if MAX_PINNED_MEMORY <= 0:
+        return False
+
+    if not is_device_cpu(tensor.device):
+        return False
+
+    _reap_pinned_inflight()
+    if _tensor_inflight(tensor):
+        DEFERRED_UNPIN.append(tensor)
+        return False
+    return _unpin_memory_now(tensor)
+
 def sage_attention_enabled():
     return args.use_sage_attention

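
Taken together, the two queues behave like this: an unpin request for a buffer that still has a recorded transfer is parked on DEFERRED_UNPIN and retried after the in-flight list has been reaped. A minimal CUDA-free model of that bookkeeping, using plain ids instead of events:

import collections

INFLIGHT = set()                 # stand-in for PINNED_INFLIGHT (buffer ids)
DEFERRED = collections.deque()   # stand-in for DEFERRED_UNPIN
UNPINNED = []

def unpin(buf_id):
    if buf_id in INFLIGHT:
        DEFERRED.append(buf_id)  # too early: park the request
        return False
    UNPINNED.append(buf_id)      # stand-in for _unpin_memory_now()
    return True

def reap(finished_ids):
    INFLIGHT.difference_update(finished_ids)
    for _ in range(len(DEFERRED)):
        unpin(DEFERRED.popleft())  # re-parks itself if still in flight

INFLIGHT.add(42)
unpin(42)        # deferred while the copy is outstanding
reap({42})       # copy done: the deferred unpin now succeeds
print(UNPINNED)  # [42]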
@@ -621,6 +621,16 @@ class ModelPatcher:
         weight, set_func, convert_func = get_key_weight(self.model, key)
         inplace_update = self.weight_inplace_update or inplace_update

+        if comfy.disk_weights.disk_weights_enabled() and weight is not None and weight.device.type == "meta":
+            parts = key.split(".")
+            param_name = parts[-1]
+            module = self.model
+            for part in parts[:-1]:
+                module = getattr(module, part)
+            target_device = device_to or self.offload_device or torch.device("cpu")
+            comfy.disk_weights.load_module_tensor(module, param_name, device=target_device)
+            weight, set_func, convert_func = get_key_weight(self.model, key)
+
         if key not in self.backup:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

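
The dotted state-dict key is resolved to its owning module and parameter name with a plain getattr walk before disk_weights is asked to materialize that single tensor. A generic PyTorch sketch of the resolution step:

import torch

def resolve_key(root: torch.nn.Module, key: str):
    # "block.0.weight" -> (submodule that owns it, "weight"); no ComfyUI imports needed.
    *path, param_name = key.split(".")
    module = root
    for part in path:
        module = getattr(module, part)
    return module, param_name

model = torch.nn.Sequential(torch.nn.Linear(2, 2))
module, name = resolve_key(model, "0.weight")
print(type(module).__name__, name)  # Linear weight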
@@ -650,8 +660,8 @@ class ModelPatcher:
     def unpin_weight(self, key):
         if key in self.pinned:
             weight, set_func, convert_func = get_key_weight(self.model, key)
-            comfy.model_management.unpin_memory(weight)
-            self.pinned.remove(key)
+            if comfy.model_management.unpin_memory(weight):
+                self.pinned.remove(key)

     def unpin_all_weights(self):
         for key in list(self.pinned):
@@ -885,9 +895,16 @@ class ModelPatcher:
         NS = comfy.model_management.NUM_STREAMS
         offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
         remaining_ram = None
+        cpu_device = torch.device("cpu")
         if device_to is not None and comfy.model_management.is_device_cpu(device_to):
             remaining_ram = comfy.model_management.get_free_memory(device_to)

+        def offload_module_tree(module):
+            freed = 0
+            for submodule in module.modules():
+                freed += comfy.disk_weights.offload_module_weights(submodule)
+            return freed
+
         for unload in unload_list:
             if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
                 break
@@ -922,15 +939,28 @@ class ModelPatcher:
             cast_weight = self.force_cast_weights
             freed_bytes = module_mem
             if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
-                freed_bytes = comfy.disk_weights.offload_module_weights(m)
+                freed_bytes = offload_module_tree(m)
                 if freed_bytes == 0:
                     freed_bytes = module_mem
             else:
-                if remaining_ram is not None and remaining_ram < module_mem and comfy.disk_weights.disk_weights_enabled():
-                    logging.info("Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.", n, module_mem / (1024 * 1024), remaining_ram / (1024 * 1024))
-                    freed_bytes = comfy.disk_weights.offload_module_weights(m)
-                    if freed_bytes == 0:
-                        freed_bytes = module_mem
+                if remaining_ram is not None and comfy.disk_weights.disk_weights_enabled():
+                    required_bytes = module_mem
+                    headroom = comfy.disk_weights.ram_headroom_bytes()
+                    comfy.model_management.free_memory(required_bytes + headroom, cpu_device, keep_loaded=[self])
+                    remaining_ram = comfy.model_management.get_free_memory(cpu_device)
+                    if remaining_ram < required_bytes:
+                        logging.info(
+                            "Insufficient CPU RAM for %s (need %.2f MB, free %.2f MB); offloading to disk.",
+                            n,
+                            required_bytes / (1024 * 1024),
+                            remaining_ram / (1024 * 1024),
+                        )
+                        freed_bytes = offload_module_tree(m)
+                        if freed_bytes == 0:
+                            freed_bytes = module_mem
+                    else:
+                        comfy.disk_weights.move_module_tensors(m, device_to)
+                        remaining_ram = max(0, remaining_ram - required_bytes)
                 else:
                     if comfy.disk_weights.disk_weights_enabled():
                         comfy.disk_weights.move_module_tensors(m, device_to)
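
The rewritten CPU branch now tries to make room in RAM before deciding where a module goes: free_memory() is asked for the module size plus the disk-weight headroom, free RAM is re-checked, and only a persistent shortfall sends the module to disk. A pure-logic sketch with hypothetical callbacks:

def offload_one(module_bytes, headroom, make_room, free_ram, to_disk, to_cpu):
    # Sketch of the per-module decision only; returns bytes freed by disk offload.
    make_room(module_bytes + headroom)   # ~ comfy.model_management.free_memory(...)
    if free_ram() < module_bytes:
        return to_disk()                 # RAM still short: weights go to disk (meta)
    to_cpu()                             # enough RAM: ordinary CPU offload
    return 0

freed = offload_one(
    module_bytes=512 * 2**20,
    headroom=2 * 2**30,
    make_room=lambda n: None,
    free_ram=lambda: 256 * 2**20,        # still short after eviction
    to_disk=lambda: 512 * 2**20,
    to_cpu=lambda: None,
)
print(freed)  # 536870912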
@@ -980,16 +1010,22 @@ class ModelPatcher:

     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         with self.use_ejected(skip_and_inject_on_exit_only=True):
+            offload_device = self.offload_device
+            if comfy.disk_weights.disk_weights_enabled() and device_to is not None:
+                if comfy.model_management.is_device_cpu(device_to):
+                    offload_device = torch.device("meta")
+                else:
+                    offload_device = torch.device("cpu")
             unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
             # TODO: force_patch_weights should not unload + reload full model
             used = self.model.model_loaded_weight_memory
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
+            self.unpatch_model(offload_device, unpatch_weights=unpatch_weights)
             if unpatch_weights:
                 extra_memory += (used - self.model.model_loaded_weight_memory)

             self.patch_model(load_weights=False)
             if extra_memory < 0 and not unpatch_weights:
-                self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+                self.partially_unload(offload_device, -extra_memory, force_patch_weights=force_patch_weights)
                 return 0
             full_load = False
             if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
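
partially_load() now picks the offload target up front: with disk weights enabled, a model headed for the CPU is unpatched straight to meta (its weights stay on disk), while a model headed for an accelerator is unpatched to CPU RAM. A small hedged stand-in for that choice:

import torch

def pick_offload_device(default, device_to, disk_weights_enabled):
    # Illustrative only; the real code checks comfy.model_management.is_device_cpu().
    if not disk_weights_enabled or device_to is None:
        return default
    return torch.device("meta") if device_to.type == "cpu" else torch.device("cpu")

print(pick_offload_device(torch.device("cpu"), torch.device("cuda"), True))  # cpu
print(pick_offload_device(torch.device("cpu"), torch.device("cpu"), True))   # meta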
comfy/ops.py | 32
@@ -535,10 +535,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
         key = f"{prefix}{param_name}"
         value = state_dict.pop(key, None)
         if value is not None:
-            if value.device.type != "meta":
-                value = value.to(device=device)
-                if dtype is not None:
-                    value = value.view(dtype=dtype)
+            value = value.to(device=device)
+            if dtype is not None:
+                value = value.view(dtype=dtype)
             manually_loaded_keys.append(key)
         return value

@@ -555,16 +554,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
         manually_loaded_keys = [weight_key]

         layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
-        if layer_conf is not None and layer_conf.device.type != "meta":
+        if layer_conf is not None:
             layer_conf = json.loads(layer_conf.numpy().tobytes())
-        elif layer_conf is not None:
-            layer_conf = None

         if layer_conf is None:
-            if weight.device.type == "meta":
-                self.weight = torch.nn.Parameter(weight, requires_grad=False)
-            else:
-                self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+            self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
         else:
             self.quant_format = layer_conf.get("format", None)
             self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
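
The kept json.loads(layer_conf.numpy().tobytes()) line works because the per-layer quantization config travels through the state dict as raw JSON bytes packed into a tensor; the dropped branch only handled the case where that tensor was still on meta. A round-trip sketch (the config values below are illustrative):

import json
import torch

conf = {"format": "float8_e4m3fn", "full_precision_matrix_mult": False}
packed = torch.tensor(list(json.dumps(conf).encode("utf-8")), dtype=torch.uint8)
decoded = json.loads(packed.numpy().tobytes())
print(decoded["format"])  # float8_e4m3fn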
@@ -610,13 +604,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
             else:
                 raise ValueError(f"Unsupported quantization format: {self.quant_format}")

-            if weight.device.type == "meta":
-                self.weight = torch.nn.Parameter(weight, requires_grad=False)
-            else:
-                self.weight = torch.nn.Parameter(
-                    QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
-                    requires_grad=False
-                )
+            self.weight = torch.nn.Parameter(
+                QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
+                requires_grad=False
+            )

             for param_name in qconfig["parameters"]:
                 if param_name in {"weight_scale", "weight_scale_2"}:
@@ -626,10 +617,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                 _v = state_dict.pop(param_key, None)
                 if _v is None:
                     continue
-                if _v.device.type == "meta":
-                    self.register_parameter(param_name, torch.nn.Parameter(_v, requires_grad=False))
-                else:
-                    self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
                 manually_loaded_keys.append(param_key)

         super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)