Integrate disk offload into memory management

ifilipis 2026-01-09 00:38:37 +02:00
parent 557e4ee341
commit 97189bf6bb
4 changed files with 139 additions and 13 deletions

View File

@@ -372,7 +372,11 @@ def _device_free_memory(device: torch.device) -> int:
def _evict_ram_for_budget(required_bytes: int) -> int:
if required_bytes <= 0:
return 0
-return evict_ram_cache(required_bytes)
+freed = evict_ram_cache(required_bytes)
+if freed < required_bytes:
+    from . import model_management
+    freed += model_management.evict_ram_to_disk(required_bytes - freed)
+return freed
def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
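The eviction budget is now satisfied in two stages: the RAM weight cache is dropped first, and only the remainder is pushed out to disk. A minimal sketch of that cascade, with both stages passed in as hypothetical callbacks rather than the real comfy functions:

def evict_for_budget(required_bytes, evict_ram_cache, evict_ram_to_disk):
    # Stage 1: drop cached RAM copies of weights that already have a disk backing.
    freed = evict_ram_cache(required_bytes)
    # Stage 2: if the cache alone was not enough, move live weights out to disk.
    if freed < required_bytes:
        freed += evict_ram_to_disk(required_bytes - freed)
    return freed

# Hypothetical usage: each callback returns the number of bytes it released.
assert evict_for_budget(100, lambda n: n // 2, lambda n: n) == 100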
@@ -654,6 +658,16 @@ def _find_tensor_dtype(args, kwargs) -> Optional[torch.dtype]:
return check(kwargs)
def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype: Optional[torch.dtype]) -> Optional[torch.dtype]:
    if manual_cast_dtype is not None:
        return manual_cast_dtype
    if input_dtype is None:
        return None
    if torch.is_floating_point(torch.empty((), dtype=input_dtype)):
        return input_dtype
    return None
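In words: an explicit manual_cast_dtype always wins; otherwise the input dtype is used only if it is a floating-point type, since integer or bool inputs say nothing about weight precision. A standalone sketch of the same rule (the local helper name is hypothetical):

import torch
from typing import Optional

def select_weight_dtype(input_dtype: Optional[torch.dtype],
                        manual_cast_dtype: Optional[torch.dtype]) -> Optional[torch.dtype]:
    # Mirrors _select_weight_dtype above: manual cast wins, then floating-point inputs.
    if manual_cast_dtype is not None:
        return manual_cast_dtype
    if input_dtype is None:
        return None
    if torch.is_floating_point(torch.empty((), dtype=input_dtype)):
        return input_dtype
    return None

assert select_weight_dtype(torch.float16, None) == torch.float16          # follow the input
assert select_weight_dtype(torch.int64, None) is None                      # ints don't drive casting
assert select_weight_dtype(torch.float16, torch.bfloat16) == torch.bfloat16  # manual cast wins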
def ensure_module_materialized(
module: torch.nn.Module,
target_device: torch.device,
@@ -744,7 +758,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
return
input_dtype = _find_tensor_dtype(args, kwargs)
manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
-dtype_override = manual_cast_dtype or input_dtype
+dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
if getattr(module, "comfy_cast_weights", False):
target_device = torch.device("cpu")
fallback_device = _find_tensor_device(args, kwargs)
@@ -793,6 +807,15 @@ def _extract_to_device(args, kwargs) -> Optional[torch.device]:
return None
def _extract_to_dtype(args, kwargs) -> Optional[torch.dtype]:
    if "dtype" in kwargs and kwargs["dtype"] is not None:
        return kwargs["dtype"]
    for arg in args:
        if isinstance(arg, torch.dtype):
            return arg
    return None
def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
for param in module.parameters(recurse=True):
if param is not None and param.device.type != "meta":
@@ -803,12 +826,58 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
return None
def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
    def _move(tensor):
        if tensor is None:
            return None
        if tensor.device.type == "meta":
            return tensor
        if dtype_override is not None and tensor.dtype != dtype_override:
            return tensor.to(device=device_to, dtype=dtype_override)
        return tensor.to(device=device_to)
    module._apply(_move)
    return module

def offload_module_weights(module: torch.nn.Module) -> int:
    if not disk_weights_enabled():
        return 0
    refs = REGISTRY.get(module)
    if not refs:
        return 0
    offloaded_bytes = 0
    if module in LAZY_MODULE_STATE:
        ref_name = next(iter(refs.keys()), None)
        if ref_name is not None:
            _evict_module_weight(module, ref_name, False)
        for disk_ref in refs.values():
            nbytes = _meta_nbytes(disk_ref.meta)
            if nbytes is not None:
                offloaded_bytes += nbytes
        return offloaded_bytes
    for name, disk_ref in refs.items():
        _evict_module_weight(module, name, disk_ref.is_buffer)
        nbytes = _meta_nbytes(disk_ref.meta)
        if nbytes is not None:
            offloaded_bytes += nbytes
    return offloaded_bytes
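The byte counts above come from the module's disk-weight registry (_meta_nbytes over the recorded meta tensors). As a rough standalone illustration of the same idea, parameters that have been swapped to the meta device keep their shape and dtype, so their nominal size can still be summed; meta_param_bytes below is a hypothetical helper, not part of the module:

import torch

def meta_param_bytes(module: torch.nn.Module) -> int:
    # Meta-device tensors carry no storage, but numel() and element_size()
    # still describe how many bytes the real weights would occupy.
    total = 0
    for p in module.parameters(recurse=True):
        if p.device.type == "meta":
            total += p.numel() * p.element_size()
    return total

# Hypothetical usage on a module whose weights were moved to the meta device:
lin = torch.nn.Linear(8, 8).to_empty(device="meta")
assert meta_param_bytes(lin) == (8 * 8 + 8) * 4  # fp32 weight + bias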
def module_to(module: torch.nn.Module, *args, **kwargs):
allow_materialize = kwargs.pop("allow_materialize", True)
if disk_weights_enabled():
target_device = _extract_to_device(args, kwargs)
if target_device is None:
target_device = _find_existing_device(module) or torch.device("cpu")
materialize_module_tree(module, target_device)
if target_device.type == "meta":
offload_module_weights(module)
return module
if allow_materialize:
materialize_module_tree(module, target_device)
return module.to(*args, **kwargs)
dtype_override = _extract_to_dtype(args, kwargs)
return move_module_tensors(module, target_device, dtype_override=dtype_override)
return module.to(*args, **kwargs)
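module_to therefore routes a .to() call three ways: a meta target means "offload to disk", a normal target with allow_materialize loads disk-resident weights and then moves the module, and allow_materialize=False only relocates tensors that are already resident. A minimal sketch of that routing, with the comfy.disk_weights helpers replaced by hypothetical callbacks:

import torch

def route_module_to(module, target_device, *, allow_materialize=True,
                    offload_to_disk=None, materialize=None, move_materialized=None):
    # torch.device("meta") is treated as "this module's weights live on disk".
    if target_device.type == "meta":
        offload_to_disk(module)
        return module
    if allow_materialize:
        # Load any disk-resident weights first, then do an ordinary .to().
        materialize(module, target_device)
        return module.to(target_device)
    # Otherwise move only what is already resident; disk references stay put.
    return move_materialized(module, target_device)

# Hypothetical usage with no-op callbacks:
lin = torch.nn.Linear(4, 4)
route_module_to(lin, torch.device("cpu"),
                offload_to_disk=lambda m: m,
                materialize=lambda m, d: m,
                move_materialized=lambda m, d: m.to(d))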

View File

@@ -56,6 +56,7 @@ import comfy.conds
import comfy.ops
from enum import Enum
from . import utils
from . import safetensors_stream
import comfy.latent_formats
import comfy.model_sampling
import math
@@ -299,7 +300,14 @@ class BaseModel(torch.nn.Module):
return out
def load_model_weights(self, sd, unet_prefix=""):
-to_load = utils.state_dict_prefix_replace(sd, {unet_prefix: ""}, filter_keys=True)
+replace_prefix = {unet_prefix: ""} if unet_prefix else {}
+if replace_prefix:
+    if utils.is_stream_state_dict(sd):
+        to_load = utils.state_dict_prefix_replace(sd, replace_prefix, filter_keys=True)
+    else:
+        to_load = safetensors_stream.RenameViewStateDict(sd, replace_prefix, filter_keys=True, mutate_base=False)
+else:
+    to_load = sd
to_load = self.model_config.process_unet_state_dict(to_load)
m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
if len(m) > 0:
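Depending on whether the state dict is a streamed view, the prefix is rewritten either via utils.state_dict_prefix_replace or through a RenameViewStateDict wrapper that avoids mutating the underlying dict. For reference, a rough sketch of what a prefix rewrite with key filtering amounts to for an ordinary dict (strip_prefix is a hypothetical stand-in, assuming filter_keys drops keys without the prefix):

def strip_prefix(sd, prefix):
    # Keep only keys under `prefix` and strip it, e.g. with prefix "unet.":
    # "unet.block.weight" -> "block.weight", while "vae.decoder.bias" is dropped.
    return {k[len(prefix):]: v for k, v in sd.items() if k.startswith(prefix)}

assert strip_prefix({"unet.block.weight": 1, "vae.decoder.bias": 2}, "unet.") == {"block.weight": 1}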

View File

@@ -530,7 +530,12 @@ class LoadedModel:
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
if freed >= memory_to_free:
return False
-self.model.detach(unpatch_weights)
+offload_device = None
+if comfy.disk_weights.disk_weights_enabled():
+    offload_device = torch.device("meta")
+self.model.detach(unpatch_weights, offload_device=offload_device)
+if offload_device is not None and offload_device.type == "meta":
+    logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
self.model_finalizer.detach()
self.model_finalizer = None
self.real_model = None
@@ -585,7 +590,9 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[]):
cleanup_models_gc()
if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
-comfy.disk_weights.evict_ram_cache(memory_required)
+freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
+if freed_cache < memory_required:
+    evict_ram_to_disk(memory_required - freed_cache)
unloaded_model = []
can_unload = []
unloaded_models = []
@@ -621,6 +628,34 @@ def free_memory(memory_required, device, keep_loaded=[]):
soft_empty_cache()
return unloaded_models
def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
    if memory_to_free <= 0:
        return 0
    if not comfy.disk_weights.disk_weights_enabled():
        return 0
    freed = 0
    can_unload = []
    for i in range(len(current_loaded_models) - 1, -1, -1):
        shift_model = current_loaded_models[i]
        if shift_model not in keep_loaded and not shift_model.is_dead():
            loaded_memory = shift_model.model_loaded_memory()
            if loaded_memory > 0:
                can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
    for x in sorted(can_unload):
        i = x[-1]
        memory_needed = memory_to_free - freed
        if memory_needed <= 0:
            break
        logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
        freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
    if freed > 0:
        logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
    return freed
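The sort key (-loaded_memory, refcount, model_memory, index) makes the largest resident models the first candidates for disk offload, with ties broken toward the least-referenced ones. A small standalone sketch of that ordering (the record format here is hypothetical):

def eviction_order(models):
    # models: list of (name, loaded_bytes, refcount); mirrors the sort key above.
    candidates = [(-loaded, refs, name) for name, loaded, refs in models if loaded > 0]
    return [name for _, _, name in sorted(candidates)]

order = eviction_order([("vae", 1 * 1024**3, 2), ("unet", 6 * 1024**3, 3), ("clip", 0, 1)])
assert order == ["unet", "vae"]  # biggest loaded model goes to disk first; unloaded models are skipped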
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -1293,7 +1328,10 @@ def get_free_memory(dev=None, torch_free_too=False):
if dev is None:
dev = get_torch_device()
-if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+if hasattr(dev, 'type') and dev.type == "meta":
+    mem_free_total = sys.maxsize
+    mem_free_torch = mem_free_total
+elif hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else:
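With the meta device reporting effectively unlimited free space, offload targets on "meta" never trigger further eviction. A minimal sketch of the branch added above (free_memory_bytes is a hypothetical stand-in; it needs psutil installed, as the original module does):

import sys
import psutil

def free_memory_bytes(dev_type: str) -> int:
    # The meta device is virtual, so report effectively unlimited space;
    # cpu/mps report available system RAM; GPUs use backend-specific queries.
    if dev_type == "meta":
        return sys.maxsize
    if dev_type in ("cpu", "mps"):
        return psutil.virtual_memory().available
    raise NotImplementedError("query the GPU backend here")

assert free_memory_bytes("meta") == sys.maxsize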

View File

@@ -857,7 +857,7 @@ class ModelPatcher:
self.backup.clear()
if device_to is not None:
-comfy.disk_weights.module_to(self.model, device_to)
+comfy.disk_weights.module_to(self.model, device_to, allow_materialize=False)
self.model.device = device_to
self.model.model_loaded_weight_memory = 0
self.model.model_offload_buffer_memory = 0
@@ -917,7 +917,16 @@ class ModelPatcher:
bias_key = "{}.bias".format(n)
if move_weight:
cast_weight = self.force_cast_weights
-m.to(device_to)
+freed_bytes = module_mem
+if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
+    freed_bytes = comfy.disk_weights.offload_module_weights(m)
+    if freed_bytes == 0:
+        freed_bytes = module_mem
+else:
+    if comfy.disk_weights.disk_weights_enabled():
+        comfy.disk_weights.move_module_tensors(m, device_to)
+    else:
+        m.to(device_to)
module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
@@ -940,7 +949,7 @@ class ModelPatcher:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
-memory_freed += module_mem
+memory_freed += freed_bytes
offload_buffer = max(offload_buffer, potential_offload)
offload_weight_factor.append(module_mem)
offload_weight_factor.pop(0)
@@ -954,7 +963,8 @@ class ModelPatcher:
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
self.model.model_offload_buffer_memory = offload_buffer
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
target_label = "disk" if device_to is not None and device_to.type == "meta" else device_to
logging.info("Unloaded partially to {}: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(target_label, memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
return memory_freed
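When a module is pushed to disk, the bytes counted as freed are those actually handed to the disk-weight registry, falling back to the module's size estimate if nothing was registered; other targets keep the old estimate-based accounting. A tiny sketch of that bookkeeping (hypothetical helper):

def freed_for_module(module_mem, to_disk, disk_offload_bytes):
    # Mirrors the freed_bytes logic above: prefer the registry's exact count
    # when offloading to disk, otherwise fall back to the module size estimate.
    if to_disk:
        return disk_offload_bytes if disk_offload_bytes > 0 else module_mem
    return module_mem

assert freed_for_module(1024, True, 768) == 768
assert freed_for_module(1024, True, 0) == 1024
assert freed_for_module(1024, False, 0) == 1024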
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@ -985,11 +995,12 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used
-def detach(self, unpatch_all=True):
+def detach(self, unpatch_all=True, offload_device=None):
self.eject_model()
self.model_patches_to(self.offload_device)
+target_device = self.offload_device if offload_device is None else offload_device
if unpatch_all:
-    self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+    self.unpatch_model(target_device, unpatch_weights=unpatch_all)
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
callback(self, unpatch_all)
return self.model