Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 23:00:51 +08:00)
Integrate disk offload into memory management
This commit is contained in: commit 97189bf6bb (parent 557e4ee341)
@@ -372,7 +372,11 @@ def _device_free_memory(device: torch.device) -> int:
 def _evict_ram_for_budget(required_bytes: int) -> int:
     if required_bytes <= 0:
         return 0
-    return evict_ram_cache(required_bytes)
+    freed = evict_ram_cache(required_bytes)
+    if freed < required_bytes:
+        from . import model_management
+        freed += model_management.evict_ram_to_disk(required_bytes - freed)
+    return freed


 def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
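Note: the hunk above turns the RAM budget helper into a two-stage cascade: drop cached RAM copies first, then spill only the shortfall to disk via model_management.evict_ram_to_disk. A minimal standalone sketch of that pattern, where evict_cache and evict_to_disk are hypothetical stand-ins rather than ComfyUI functions:

from typing import Callable

def evict_for_budget(required_bytes: int,
                     evict_cache: Callable[[int], int],
                     evict_to_disk: Callable[[int], int]) -> int:
    # Mirrors _evict_ram_for_budget: cheap cache eviction first, disk spill only for the shortfall.
    if required_bytes <= 0:
        return 0
    freed = evict_cache(required_bytes)
    if freed < required_bytes:
        freed += evict_to_disk(required_bytes - freed)
    return freed

# Example: the cache yields 1 MiB, the remaining 3 MiB has to come from disk offload.
print(evict_for_budget(4 << 20, lambda n: 1 << 20, lambda n: n))  # 4194304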
@@ -654,6 +658,16 @@ def _find_tensor_dtype(args, kwargs) -> Optional[torch.dtype]:
     return check(kwargs)


+def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype: Optional[torch.dtype]) -> Optional[torch.dtype]:
+    if manual_cast_dtype is not None:
+        return manual_cast_dtype
+    if input_dtype is None:
+        return None
+    if torch.is_floating_point(torch.empty((), dtype=input_dtype)):
+        return input_dtype
+    return None
+
+
 def ensure_module_materialized(
     module: torch.nn.Module,
     target_device: torch.device,
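The priority order of _select_weight_dtype added above can be shown directly: an explicit manual_cast_dtype always wins, a floating-point input dtype propagates, and integer or missing inputs produce no override. A small check, assuming torch is installed and with the function repeated verbatim so the snippet runs on its own:

import torch

def _select_weight_dtype(input_dtype, manual_cast_dtype):
    # Copy of the helper from the hunk above, included only to keep this example self-contained.
    if manual_cast_dtype is not None:
        return manual_cast_dtype
    if input_dtype is None:
        return None
    if torch.is_floating_point(torch.empty((), dtype=input_dtype)):
        return input_dtype
    return None

assert _select_weight_dtype(torch.float32, torch.float16) == torch.float16  # manual cast wins
assert _select_weight_dtype(torch.bfloat16, None) == torch.bfloat16         # float input propagates
assert _select_weight_dtype(torch.int64, None) is None                      # non-float input: no override
assert _select_weight_dtype(None, None) is None                             # nothing to infer from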
@@ -744,7 +758,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
         return
     input_dtype = _find_tensor_dtype(args, kwargs)
     manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
-    dtype_override = manual_cast_dtype or input_dtype
+    dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
     if getattr(module, "comfy_cast_weights", False):
         target_device = torch.device("cpu")
     fallback_device = _find_tensor_device(args, kwargs)
@@ -793,6 +807,15 @@ def _extract_to_device(args, kwargs) -> Optional[torch.device]:
     return None


+def _extract_to_dtype(args, kwargs) -> Optional[torch.dtype]:
+    if "dtype" in kwargs and kwargs["dtype"] is not None:
+        return kwargs["dtype"]
+    for arg in args:
+        if isinstance(arg, torch.dtype):
+            return arg
+    return None
+
+
 def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
     for param in module.parameters(recurse=True):
         if param is not None and param.device.type != "meta":
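_extract_to_dtype mirrors the calling conventions of Tensor.to(), where a dtype can arrive as a keyword or as a bare positional argument. A quick self-contained check of both paths (the helper is copied from the hunk above; torch is assumed to be available):

import torch

def _extract_to_dtype(args, kwargs):
    # Copy of the helper above so the example runs standalone.
    if "dtype" in kwargs and kwargs["dtype"] is not None:
        return kwargs["dtype"]
    for arg in args:
        if isinstance(arg, torch.dtype):
            return arg
    return None

assert _extract_to_dtype((), {"dtype": torch.float16}) == torch.float16                # .to(dtype=...)
assert _extract_to_dtype((torch.device("cpu"), torch.bfloat16), {}) == torch.bfloat16  # .to(device, dtype)
assert _extract_to_dtype((torch.device("cpu"),), {}) is None                           # no dtype given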
@@ -803,12 +826,58 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
     return None


+def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
+    def _move(tensor):
+        if tensor is None:
+            return None
+        if tensor.device.type == "meta":
+            return tensor
+        if dtype_override is not None and tensor.dtype != dtype_override:
+            return tensor.to(device=device_to, dtype=dtype_override)
+        return tensor.to(device=device_to)
+
+    module._apply(_move)
+    return module
+
+
+def offload_module_weights(module: torch.nn.Module) -> int:
+    if not disk_weights_enabled():
+        return 0
+    refs = REGISTRY.get(module)
+    if not refs:
+        return 0
+    offloaded_bytes = 0
+    if module in LAZY_MODULE_STATE:
+        ref_name = next(iter(refs.keys()), None)
+        if ref_name is not None:
+            _evict_module_weight(module, ref_name, False)
+        for disk_ref in refs.values():
+            nbytes = _meta_nbytes(disk_ref.meta)
+            if nbytes is not None:
+                offloaded_bytes += nbytes
+        return offloaded_bytes
+    for name, disk_ref in refs.items():
+        _evict_module_weight(module, name, disk_ref.is_buffer)
+        nbytes = _meta_nbytes(disk_ref.meta)
+        if nbytes is not None:
+            offloaded_bytes += nbytes
+    return offloaded_bytes
+
+
 def module_to(module: torch.nn.Module, *args, **kwargs):
+    allow_materialize = kwargs.pop("allow_materialize", True)
     if disk_weights_enabled():
         target_device = _extract_to_device(args, kwargs)
         if target_device is None:
             target_device = _find_existing_device(module) or torch.device("cpu")
-        materialize_module_tree(module, target_device)
+        if target_device.type == "meta":
+            offload_module_weights(module)
+            return module
+        if allow_materialize:
+            materialize_module_tree(module, target_device)
+            return module.to(*args, **kwargs)
+        dtype_override = _extract_to_dtype(args, kwargs)
+        return move_module_tensors(module, target_device, dtype_override=dtype_override)
     return module.to(*args, **kwargs)

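module_to now treats a move to the meta device as "offload this module to disk": meta tensors keep shape and dtype but carry no storage, so RAM is released while enough metadata survives to rebuild the weight later. A torch-only illustration of that property, independent of ComfyUI's disk_weights registry:

import torch

layer = torch.nn.Linear(1024, 1024, bias=False)
print(layer.weight.device, layer.weight.nbytes)  # cpu 4194304 -- real 4 MiB of storage

layer = layer.to(device="meta")                  # data is dropped, metadata survives
w = layer.weight
print(w.device, w.shape, w.dtype)                # meta torch.Size([1024, 1024]) torch.float32
print(w.is_meta)                                 # True: no backing storage to read or write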
@@ -56,6 +56,7 @@ import comfy.conds
 import comfy.ops
 from enum import Enum
 from . import utils
+from . import safetensors_stream
 import comfy.latent_formats
 import comfy.model_sampling
 import math
@@ -299,7 +300,14 @@ class BaseModel(torch.nn.Module):
         return out

     def load_model_weights(self, sd, unet_prefix=""):
-        to_load = utils.state_dict_prefix_replace(sd, {unet_prefix: ""}, filter_keys=True)
+        replace_prefix = {unet_prefix: ""} if unet_prefix else {}
+        if replace_prefix:
+            if utils.is_stream_state_dict(sd):
+                to_load = utils.state_dict_prefix_replace(sd, replace_prefix, filter_keys=True)
+            else:
+                to_load = safetensors_stream.RenameViewStateDict(sd, replace_prefix, filter_keys=True, mutate_base=False)
+        else:
+            to_load = sd
         to_load = self.model_config.process_unet_state_dict(to_load)
         m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
         if len(m) > 0:
@@ -530,7 +530,12 @@ class LoadedModel:
         freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
         if freed >= memory_to_free:
             return False
-        self.model.detach(unpatch_weights)
+        offload_device = None
+        if comfy.disk_weights.disk_weights_enabled():
+            offload_device = torch.device("meta")
+        self.model.detach(unpatch_weights, offload_device=offload_device)
+        if offload_device is not None and offload_device.type == "meta":
+            logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
         self.model_finalizer.detach()
         self.model_finalizer = None
         self.real_model = None
@@ -585,7 +590,9 @@ def minimum_inference_memory():
 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
     if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
-        comfy.disk_weights.evict_ram_cache(memory_required)
+        freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
+        if freed_cache < memory_required:
+            evict_ram_to_disk(memory_required - freed_cache)
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -621,6 +628,34 @@ def free_memory(memory_required, device, keep_loaded=[]):
     soft_empty_cache()
     return unloaded_models

+
+def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
+    if memory_to_free <= 0:
+        return 0
+    if not comfy.disk_weights.disk_weights_enabled():
+        return 0
+
+    freed = 0
+    can_unload = []
+    for i in range(len(current_loaded_models) - 1, -1, -1):
+        shift_model = current_loaded_models[i]
+        if shift_model not in keep_loaded and not shift_model.is_dead():
+            loaded_memory = shift_model.model_loaded_memory()
+            if loaded_memory > 0:
+                can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+
+    for x in sorted(can_unload):
+        i = x[-1]
+        memory_needed = memory_to_free - freed
+        if memory_needed <= 0:
+            break
+        logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
+        freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
+
+    if freed > 0:
+        logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
+    return freed
+
 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state
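evict_ram_to_disk reuses the unload-ordering heuristic from free_memory: candidates are keyed by negative resident memory so the models holding the most RAM are spilled first, and the greedy loop stops as soon as the requested budget is met. The ordering and early exit can be sketched with plain data (the model names and sizes here are made up for illustration):

models = [("vae", 300 << 20), ("unet", 5 << 30), ("clip", 1 << 30)]  # (name, resident bytes)
budget = 5 << 30  # bytes we need to free

freed = 0
for neg_loaded, name in sorted((-size, name) for name, size in models):
    if budget - freed <= 0:
        break  # enough freed, leave the remaining models in RAM
    freed += -neg_loaded  # pretend the whole model could be offloaded to disk
    print(f"offloading {name} to disk, freed so far: {freed / (1 << 20):.0f} MB")
# Only unet (5 GiB) is offloaded; clip and vae stay resident.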
@@ -1293,7 +1328,10 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()

-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and dev.type == "meta":
+        mem_free_total = sys.maxsize
+        mem_free_torch = mem_free_total
+    elif hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
@@ -857,7 +857,7 @@ class ModelPatcher:
         self.backup.clear()

         if device_to is not None:
-            comfy.disk_weights.module_to(self.model, device_to)
+            comfy.disk_weights.module_to(self.model, device_to, allow_materialize=False)
             self.model.device = device_to
             self.model.model_loaded_weight_memory = 0
             self.model.model_offload_buffer_memory = 0
@@ -917,7 +917,16 @@ class ModelPatcher:
             bias_key = "{}.bias".format(n)
             if move_weight:
                 cast_weight = self.force_cast_weights
-                m.to(device_to)
+                freed_bytes = module_mem
+                if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
+                    freed_bytes = comfy.disk_weights.offload_module_weights(m)
+                    if freed_bytes == 0:
+                        freed_bytes = module_mem
+                else:
+                    if comfy.disk_weights.disk_weights_enabled():
+                        comfy.disk_weights.move_module_tensors(m, device_to)
+                    else:
+                        m.to(device_to)
                 module_mem += move_weight_functions(m, device_to)
                 if lowvram_possible:
                     if weight_key in self.patches:
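In partially_unload, the bytes credited as freed now depend on which path ran: a real disk offload reports the bytes actually evicted via offload_module_weights, with the estimated module_mem kept as a fallback when the module had nothing registered. A compact sketch of that accounting decision (the function and its arguments are illustrative stand-ins, not ComfyUI API):

def freed_for_module(module_mem: int, to_meta: bool, disk_enabled: bool, offload_fn) -> int:
    # Prefer the measured number of bytes written to disk; fall back to the size estimate.
    freed = module_mem
    if to_meta and disk_enabled:
        measured = offload_fn()
        if measured != 0:
            freed = measured
    return freed

assert freed_for_module(10, to_meta=True, disk_enabled=True, offload_fn=lambda: 7) == 7    # measured wins
assert freed_for_module(10, to_meta=True, disk_enabled=True, offload_fn=lambda: 0) == 10   # fallback estimate
assert freed_for_module(10, to_meta=False, disk_enabled=True, offload_fn=lambda: 7) == 10  # not a disk offload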
@@ -940,7 +949,7 @@ class ModelPatcher:
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
                     m.comfy_cast_weights = True
                 m.comfy_patched_weights = False
-                memory_freed += module_mem
+                memory_freed += freed_bytes
                 offload_buffer = max(offload_buffer, potential_offload)
                 offload_weight_factor.append(module_mem)
                 offload_weight_factor.pop(0)
@@ -954,7 +963,8 @@ class ModelPatcher:
         self.model.lowvram_patch_counter += patch_counter
         self.model.model_loaded_weight_memory -= memory_freed
         self.model.model_offload_buffer_memory = offload_buffer
-        logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
+        target_label = "disk" if device_to is not None and device_to.type == "meta" else device_to
+        logging.info("Unloaded partially to {}: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(target_label, memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
         return memory_freed

     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -985,11 +995,12 @@ class ModelPatcher:

         return self.model.model_loaded_weight_memory - current_used

-    def detach(self, unpatch_all=True):
+    def detach(self, unpatch_all=True, offload_device=None):
         self.eject_model()
         self.model_patches_to(self.offload_device)
+        target_device = self.offload_device if offload_device is None else offload_device
         if unpatch_all:
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+            self.unpatch_model(target_device, unpatch_weights=unpatch_all)
         for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
             callback(self, unpatch_all)
         return self.model