Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-23 21:00:16 +08:00)

Integrate disk offload into memory management

commit 97189bf6bb (parent 557e4ee341)
@@ -372,7 +372,11 @@ def _device_free_memory(device: torch.device) -> int:
 def _evict_ram_for_budget(required_bytes: int) -> int:
     if required_bytes <= 0:
         return 0
-    return evict_ram_cache(required_bytes)
+    freed = evict_ram_cache(required_bytes)
+    if freed < required_bytes:
+        from . import model_management
+        freed += model_management.evict_ram_to_disk(required_bytes - freed)
+    return freed
 
 
 def _maybe_free_ram_budget(device: torch.device, required_bytes: int) -> int:
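The hunk above introduces a two-tier eviction order: the RAM weight cache is drained first, and only the shortfall is pushed out through the disk offloader. A minimal, self-contained sketch of that ordering (the two evictor callables are hypothetical stand-ins for evict_ram_cache() and model_management.evict_ram_to_disk(), not the project's actual functions):

def evict_for_budget(required_bytes, evict_ram_cache, evict_ram_to_disk):
    # Free from the cheap tier first, then offload only the remainder to disk.
    if required_bytes <= 0:
        return 0
    freed = evict_ram_cache(required_bytes)
    if freed < required_bytes:
        freed += evict_ram_to_disk(required_bytes - freed)
    return freed

print(evict_for_budget(100, lambda n: 60, lambda n: n))  # 100: 60 from the cache, 40 offloaded to disk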
@@ -654,6 +658,16 @@ def _find_tensor_dtype(args, kwargs) -> Optional[torch.dtype]:
     return check(kwargs)
 
 
+def _select_weight_dtype(input_dtype: Optional[torch.dtype], manual_cast_dtype: Optional[torch.dtype]) -> Optional[torch.dtype]:
+    if manual_cast_dtype is not None:
+        return manual_cast_dtype
+    if input_dtype is None:
+        return None
+    if torch.is_floating_point(torch.empty((), dtype=input_dtype)):
+        return input_dtype
+    return None
+
+
 def ensure_module_materialized(
     module: torch.nn.Module,
     target_device: torch.device,
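For reference, the selection rule the new helper encodes: an explicit manual-cast dtype always wins, and otherwise the activation dtype is used only when it is floating point. A small standalone copy of the logic to show the outcomes (requires torch):

import torch

def select_weight_dtype(input_dtype, manual_cast_dtype):
    # Local mirror of the helper's rule, for illustration only.
    if manual_cast_dtype is not None:
        return manual_cast_dtype
    if input_dtype is None:
        return None
    if torch.is_floating_point(torch.empty((), dtype=input_dtype)):
        return input_dtype
    return None

print(select_weight_dtype(torch.float16, None))            # torch.float16
print(select_weight_dtype(torch.int64, None))              # None: integer inputs don't force a cast
print(select_weight_dtype(torch.float16, torch.bfloat16))  # torch.bfloat16: manual cast wins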
@@ -744,7 +758,7 @@ def disk_weight_pre_hook(module: torch.nn.Module, args, kwargs):
         return
     input_dtype = _find_tensor_dtype(args, kwargs)
     manual_cast_dtype = getattr(module, "manual_cast_dtype", None)
-    dtype_override = manual_cast_dtype or input_dtype
+    dtype_override = _select_weight_dtype(input_dtype, manual_cast_dtype)
     if getattr(module, "comfy_cast_weights", False):
         target_device = torch.device("cpu")
     fallback_device = _find_tensor_device(args, kwargs)
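The practical difference at this call site, presumably the point of the change: with an integer activation dtype and no manual cast, the old expression picked the integer dtype as the weight override, while the helper now returns None so weights keep their stored floating-point dtype. A quick check (the real helper call is commented out since it lives in this module):

import torch

input_dtype, manual_cast_dtype = torch.int64, None   # e.g. token-ID inputs, no manual cast
old_override = manual_cast_dtype or input_dtype      # torch.int64: weights would be cast to an integer dtype
# new_override = _select_weight_dtype(input_dtype, manual_cast_dtype)  # None: weights keep their own dtype
print(old_override)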
@@ -793,6 +807,15 @@ def _extract_to_device(args, kwargs) -> Optional[torch.device]:
     return None
 
 
+def _extract_to_dtype(args, kwargs) -> Optional[torch.dtype]:
+    if "dtype" in kwargs and kwargs["dtype"] is not None:
+        return kwargs["dtype"]
+    for arg in args:
+        if isinstance(arg, torch.dtype):
+            return arg
+    return None
+
+
 def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
     for param in module.parameters(recurse=True):
         if param is not None and param.device.type != "meta":
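A quick usage sketch of how module.to(...)-style arguments map to a dtype under this parsing rule (local copy of the helper, for illustration; an explicit dtype kwarg wins, otherwise the first positional torch.dtype is used):

import torch

def extract_to_dtype(args, kwargs):
    if "dtype" in kwargs and kwargs["dtype"] is not None:
        return kwargs["dtype"]
    for arg in args:
        if isinstance(arg, torch.dtype):
            return arg
    return None

print(extract_to_dtype((torch.device("cuda"), torch.float16), {}))  # torch.float16
print(extract_to_dtype(("cpu",), {"dtype": torch.bfloat16}))        # torch.bfloat16
print(extract_to_dtype(("cpu",), {}))                               # None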
@@ -803,12 +826,58 @@ def _find_existing_device(module: torch.nn.Module) -> Optional[torch.device]:
     return None
 
 
+def move_module_tensors(module: torch.nn.Module, device_to: torch.device, dtype_override: Optional[torch.dtype] = None):
+    def _move(tensor):
+        if tensor is None:
+            return None
+        if tensor.device.type == "meta":
+            return tensor
+        if dtype_override is not None and tensor.dtype != dtype_override:
+            return tensor.to(device=device_to, dtype=dtype_override)
+        return tensor.to(device=device_to)
+
+    module._apply(_move)
+    return module
+
+
+def offload_module_weights(module: torch.nn.Module) -> int:
+    if not disk_weights_enabled():
+        return 0
+    refs = REGISTRY.get(module)
+    if not refs:
+        return 0
+    offloaded_bytes = 0
+    if module in LAZY_MODULE_STATE:
+        ref_name = next(iter(refs.keys()), None)
+        if ref_name is not None:
+            _evict_module_weight(module, ref_name, False)
+        for disk_ref in refs.values():
+            nbytes = _meta_nbytes(disk_ref.meta)
+            if nbytes is not None:
+                offloaded_bytes += nbytes
+        return offloaded_bytes
+    for name, disk_ref in refs.items():
+        _evict_module_weight(module, name, disk_ref.is_buffer)
+        nbytes = _meta_nbytes(disk_ref.meta)
+        if nbytes is not None:
+            offloaded_bytes += nbytes
+    return offloaded_bytes
+
+
 def module_to(module: torch.nn.Module, *args, **kwargs):
+    allow_materialize = kwargs.pop("allow_materialize", True)
     if disk_weights_enabled():
         target_device = _extract_to_device(args, kwargs)
         if target_device is None:
             target_device = _find_existing_device(module) or torch.device("cpu")
-        materialize_module_tree(module, target_device)
+        if target_device.type == "meta":
+            offload_module_weights(module)
+            return module
+        if allow_materialize:
+            materialize_module_tree(module, target_device)
+            return module.to(*args, **kwargs)
+        dtype_override = _extract_to_dtype(args, kwargs)
+        return move_module_tensors(module, target_device, dtype_override=dtype_override)
     return module.to(*args, **kwargs)
 
 
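A self-contained sketch of what the new move_module_tensors path does compared to a plain Module.to(): tensors are walked with _apply, meta placeholders (weights already offloaded to disk) are left untouched, and a dtype override is applied only where the dtype actually differs. The offload_module_weights side of the hunk depends on the module's disk-weight registry (REGISTRY, LAZY_MODULE_STATE), so it is not reproduced here.

import torch

def move_module_tensors(module, device_to, dtype_override=None):
    # Stand-alone copy of the helper, for illustration only.
    def _move(tensor):
        if tensor is None:
            return None
        if tensor.device.type == "meta":   # already offloaded: keep the placeholder
            return tensor
        if dtype_override is not None and tensor.dtype != dtype_override:
            return tensor.to(device=device_to, dtype=dtype_override)
        return tensor.to(device=device_to)
    module._apply(_move)
    return module

m = torch.nn.Linear(4, 4)
move_module_tensors(m, torch.device("cpu"), dtype_override=torch.float16)
print(m.weight.dtype)  # torch.float16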
@@ -56,6 +56,7 @@ import comfy.conds
 import comfy.ops
 from enum import Enum
 from . import utils
+from . import safetensors_stream
 import comfy.latent_formats
 import comfy.model_sampling
 import math
@@ -299,7 +300,14 @@ class BaseModel(torch.nn.Module):
         return out
 
     def load_model_weights(self, sd, unet_prefix=""):
-        to_load = utils.state_dict_prefix_replace(sd, {unet_prefix: ""}, filter_keys=True)
+        replace_prefix = {unet_prefix: ""} if unet_prefix else {}
+        if replace_prefix:
+            if utils.is_stream_state_dict(sd):
+                to_load = utils.state_dict_prefix_replace(sd, replace_prefix, filter_keys=True)
+            else:
+                to_load = safetensors_stream.RenameViewStateDict(sd, replace_prefix, filter_keys=True, mutate_base=False)
+        else:
+            to_load = sd
         to_load = self.model_config.process_unet_state_dict(to_load)
         m, u = utils.load_state_dict(self.diffusion_model, to_load, strict=False)
         if len(m) > 0:
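For context, state-dict prefix replacement boils down to renaming keys under a prefix (and optionally dropping the rest). A toy equivalent on a plain dict, not the project's streaming view classes:

def prefix_replace(sd, old_prefix, new_prefix="", filter_keys=True):
    # Keys under old_prefix are renamed; other keys are dropped when filter_keys is set.
    out = {}
    for k, v in sd.items():
        if k.startswith(old_prefix):
            out[new_prefix + k[len(old_prefix):]] = v
        elif not filter_keys:
            out[k] = v
    return out

sd = {"model.diffusion_model.input_blocks.0.weight": 1, "first_stage_model.decoder.bias": 2}
print(prefix_replace(sd, "model.diffusion_model."))
# {'input_blocks.0.weight': 1}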
@@ -530,7 +530,12 @@ class LoadedModel:
                 freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
                 if freed >= memory_to_free:
                     return False
-        self.model.detach(unpatch_weights)
+        offload_device = None
+        if comfy.disk_weights.disk_weights_enabled():
+            offload_device = torch.device("meta")
+        self.model.detach(unpatch_weights, offload_device=offload_device)
+        if offload_device is not None and offload_device.type == "meta":
+            logging.info(f"Unloaded {self.model.model.__class__.__name__} to disk")
         self.model_finalizer.detach()
         self.model_finalizer = None
         self.real_model = None
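The mechanism behind "unloaded to disk" here is that a tensor moved to the meta device keeps its shape and dtype metadata but drops its storage, so the RAM is released while the on-disk safetensors data remains the source of truth for a later reload. A minimal illustration with plain PyTorch:

import torch

layer = torch.nn.Linear(1024, 1024)
print(layer.weight.device, layer.weight.element_size() * layer.weight.nelement())  # cpu, ~4 MB of backing storage

layer.to(torch.device("meta"))                   # shape and dtype survive, storage is gone
print(layer.weight.device, layer.weight.shape)   # meta, torch.Size([1024, 1024])
# layer(torch.randn(1, 1024))                    # would fail: meta tensors hold no data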
@@ -585,7 +590,9 @@ def minimum_inference_memory():
 def free_memory(memory_required, device, keep_loaded=[]):
     cleanup_models_gc()
     if is_device_cpu(device) and comfy.disk_weights.disk_weights_enabled():
-        comfy.disk_weights.evict_ram_cache(memory_required)
+        freed_cache = comfy.disk_weights.evict_ram_cache(memory_required)
+        if freed_cache < memory_required:
+            evict_ram_to_disk(memory_required - freed_cache)
     unloaded_model = []
     can_unload = []
     unloaded_models = []
@@ -621,6 +628,34 @@ def free_memory(memory_required, device, keep_loaded=[]):
         soft_empty_cache()
     return unloaded_models
 
+
+def evict_ram_to_disk(memory_to_free, keep_loaded=[]):
+    if memory_to_free <= 0:
+        return 0
+    if not comfy.disk_weights.disk_weights_enabled():
+        return 0
+
+    freed = 0
+    can_unload = []
+    for i in range(len(current_loaded_models) - 1, -1, -1):
+        shift_model = current_loaded_models[i]
+        if shift_model not in keep_loaded and not shift_model.is_dead():
+            loaded_memory = shift_model.model_loaded_memory()
+            if loaded_memory > 0:
+                can_unload.append((-loaded_memory, sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
+
+    for x in sorted(can_unload):
+        i = x[-1]
+        memory_needed = memory_to_free - freed
+        if memory_needed <= 0:
+            break
+        logging.debug(f"Offloading {current_loaded_models[i].model.model.__class__.__name__} to disk")
+        freed += current_loaded_models[i].model.partially_unload(torch.device("meta"), memory_needed)
+
+    if freed > 0:
+        logging.info("RAM evicted to disk: {:.2f} MB freed".format(freed / (1024 * 1024)))
+    return freed
+
+
 def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state
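The candidate ordering mirrors free_memory's heuristic: sorting tuples of (-loaded_memory, refcount, total_memory, index) visits the models holding the most RAM first, with the reference count as a tie-breaker. A small demonstration of the sort key (plain tuples, illustrative values only):

# Each entry: (-loaded_memory_bytes, refcount, total_model_bytes, index)
candidates = [
    (-500_000_000, 3, 600_000_000, 0),       # 500 MB loaded
    (-2_000_000_000, 5, 2_100_000_000, 1),   # 2 GB loaded -> evicted first
    (-500_000_000, 2, 800_000_000, 2),       # same loaded size, lower refcount wins the tie
]
for entry in sorted(candidates):
    print("evict index", entry[-1])
# evict index 1, then 2, then 0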
@@ -1293,7 +1328,10 @@ def get_free_memory(dev=None, torch_free_too=False):
     if dev is None:
         dev = get_torch_device()
 
-    if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
+    if hasattr(dev, 'type') and dev.type == "meta":
+        mem_free_total = sys.maxsize
+        mem_free_torch = mem_free_total
+    elif hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_free_total = psutil.virtual_memory().available
         mem_free_torch = mem_free_total
     else:
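Reporting sys.maxsize for the meta device effectively tells callers that "disk" never runs out, so budget checks against a meta offload target always pass. A simplified sketch of the dispatch (psutil assumed available, as it is a ComfyUI dependency; the accelerator branch is omitted):

import sys
import psutil
import torch

def free_memory_for(dev: torch.device) -> int:
    if dev.type == "meta":
        return sys.maxsize                       # disk-backed: treated as unlimited
    if dev.type in ("cpu", "mps"):
        return psutil.virtual_memory().available  # available system RAM
    raise NotImplementedError("accelerator branch omitted in this sketch")

print(free_memory_for(torch.device("meta")))
print(free_memory_for(torch.device("cpu")))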
@@ -857,7 +857,7 @@ class ModelPatcher:
             self.backup.clear()
 
             if device_to is not None:
-                comfy.disk_weights.module_to(self.model, device_to)
+                comfy.disk_weights.module_to(self.model, device_to, allow_materialize=False)
                 self.model.device = device_to
                 self.model.model_loaded_weight_memory = 0
                 self.model.model_offload_buffer_memory = 0
@@ -917,7 +917,16 @@ class ModelPatcher:
             bias_key = "{}.bias".format(n)
             if move_weight:
                 cast_weight = self.force_cast_weights
-                m.to(device_to)
+                freed_bytes = module_mem
+                if device_to is not None and device_to.type == "meta" and comfy.disk_weights.disk_weights_enabled():
+                    freed_bytes = comfy.disk_weights.offload_module_weights(m)
+                    if freed_bytes == 0:
+                        freed_bytes = module_mem
+                else:
+                    if comfy.disk_weights.disk_weights_enabled():
+                        comfy.disk_weights.move_module_tensors(m, device_to)
+                    else:
+                        m.to(device_to)
                 module_mem += move_weight_functions(m, device_to)
                 if lowvram_possible:
                     if weight_key in self.patches:
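Accounting sketch for the hunk above: when a module is pushed to disk, the byte count reported by the offload call is preferred, and the precomputed module_mem estimate is only used as a fallback when the offloader reports nothing (for example, a module with no registered disk refs). The names below are hypothetical stand-ins:

def bytes_freed_for(module_mem_estimate, offload_result):
    # offload_result: bytes the disk offloader says it released (0 if it did nothing).
    freed = offload_result
    if freed == 0:
        freed = module_mem_estimate   # fall back to the size estimate
    return freed

print(bytes_freed_for(350_000_000, 348_127_232))  # measured value wins
print(bytes_freed_for(350_000_000, 0))            # estimate used as fallback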
@@ -940,7 +949,7 @@ class ModelPatcher:
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
                     m.comfy_cast_weights = True
                 m.comfy_patched_weights = False
-                memory_freed += module_mem
+                memory_freed += freed_bytes
                 offload_buffer = max(offload_buffer, potential_offload)
                 offload_weight_factor.append(module_mem)
                 offload_weight_factor.pop(0)
@@ -954,7 +963,8 @@ class ModelPatcher:
         self.model.lowvram_patch_counter += patch_counter
         self.model.model_loaded_weight_memory -= memory_freed
         self.model.model_offload_buffer_memory = offload_buffer
-        logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
+        target_label = "disk" if device_to is not None and device_to.type == "meta" else device_to
+        logging.info("Unloaded partially to {}: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(target_label, memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
         return memory_freed
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -985,11 +995,12 @@ class ModelPatcher:
 
         return self.model.model_loaded_weight_memory - current_used
 
-    def detach(self, unpatch_all=True):
+    def detach(self, unpatch_all=True, offload_device=None):
         self.eject_model()
         self.model_patches_to(self.offload_device)
+        target_device = self.offload_device if offload_device is None else offload_device
         if unpatch_all:
-            self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
+            self.unpatch_model(target_device, unpatch_weights=unpatch_all)
         for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
             callback(self, unpatch_all)
         return self.model
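With the new parameter, callers such as LoadedModel.model_unload can redirect where the unpatch lands: passing torch.device("meta") routes the weights through the disk-offload path instead of back to the usual offload device. A hypothetical caller, for illustration only:

import torch

def unload_model(patcher, disk_weights_enabled: bool):
    # With disk weights enabled, detach unpatches onto the meta device (i.e. to disk);
    # otherwise it falls back to the patcher's regular offload device.
    offload_device = torch.device("meta") if disk_weights_enabled else None
    return patcher.detach(unpatch_all=True, offload_device=offload_device)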