# ComfyUI/comfy_extras/nodes/nodes_group_offloading.py
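"""Custom nodes that apply diffusers' group offloading and layerwise casting hooks to ComfyUI-managed models.

``GroupOffload`` keeps a model resident on the offload (CPU) device and lets diffusers stream its blocks to the
load device on demand during sampling. ``LayerwiseCast`` stores the diffusion model's weights in an FP8 dtype and
upcasts them per layer at compute time; ComfyUI's own weight casting is disabled on the affected layers so the
two mechanisms do not conflict.
"""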
import torch
import logging
from diffusers import HookRegistry
from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_management import vram_state, VRAMState, unload_all_models, get_free_memory, get_torch_device
from comfy.model_management_types import HooksSupport, ModelManageable
from comfy.model_patcher import ModelPatcher
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode
from comfy.ops import manual_cast
from comfy.patcher_extension import WrappersMP
from comfy.rmsnorm import RMSNorm
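# registry key under which the DisableComfyWeightCast hook is registered on a module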
_DISABLE_COMFYUI_CASTING_HOOK = "disable_comfyui_casting_hook"
logger = logging.getLogger(__name__)
class DisableComfyWeightCast(ModelHook):
r"""
A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
footprint.
"""
_is_stateful = False
def __init__(self) -> None:
super().__init__()
def initialize_hook(self, module: torch.nn.Module):
if hasattr(module, "comfy_cast_weights"):
module.comfy_cast_weights = False
return module
    # spelled to match the diffusers ModelHook method this overrides
    def deinitalize_hook(self, module: torch.nn.Module):
if hasattr(module, "comfy_cast_weights"):
module.comfy_cast_weights = True
return module
def disable_comfyui_weight_casting_hook(module: torch.nn.Module):
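    """Registers a ``DisableComfyWeightCast`` hook on ``module`` through diffusers' hook registry."""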
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = DisableComfyWeightCast()
registry.register_hook(hook, _DISABLE_COMFYUI_CASTING_HOOK)
def disable_comfyui_weight_casting(module: torch.nn.Module):
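    """Recursively disables ComfyUI weight casting on common leaf layer types (linear, conv, norm, embedding)."""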
types = [
torch.nn.Linear,
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Conv3d,
torch.nn.GroupNorm,
torch.nn.LayerNorm,
RMSNorm,
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose1d,
torch.nn.Embedding
]
    try:
        # torch.nn.RMSNorm only exists in newer torch releases
        from torch.nn import RMSNorm as TorchRMSNorm  # pylint: disable=no-member
        types.append(TorchRMSNorm)
    except ImportError:
        pass
if isinstance(module, tuple(types)):
disable_comfyui_weight_casting_hook(module)
return
for name, submodule in module.named_children():
disable_comfyui_weight_casting(submodule)
def prepare_group_offloading_factory(load_device: torch.device, offload_device: torch.device):
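    """Builds a PREPARE_SAMPLING wrapper that keeps the model on ``offload_device`` and hands device placement
    of its blocks over to diffusers group offloading targeting ``load_device``."""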
def wrapper(executor, model: ModelPatcher, *args, **kwargs):
        # the model will now only be loaded to the CPU, since diffusers will manage moving it to the GPU
model.load_device = offload_device
        # we have to unload everything to make better use of pinned memory; this includes trimming
unload_all_models()
# loads the model, prepares everything
inner_model, conds, models = executor(model, *args, **kwargs)
        # if ComfyUI would manually cast weights at inference time, diffusers layerwise casting would be needed instead
        if model.model.operations == manual_cast and model.diffusion_model.dtype != model.model.manual_cast_dtype:
            raise ValueError("manual casting operations, where the model is stored in a different dtype than the one used for inference, are not supported with group offloading")
# weights are patched, ready to go, inner model will be correctly deleted at the end of sampling
model_size = model.model_size()
        # pinning needs roughly twice the model's size free on the host
        model_too_large = model_size * 2 > get_free_memory(torch.device("cpu"))
low_vram_state = vram_state in (VRAMState.LOW_VRAM,)
is_cuda_device = load_device.type == 'cuda'
if model_too_large or low_vram_state:
logger.error(f"group offloading did not use memory pinning because model_too_large={model_too_large} low_vram_state={low_vram_state}")
if not is_cuda_device:
logger.error(f"group offloading did not use stream because load_device.type={load_device.type} != \"cuda\"")
apply_group_offloading(
inner_model.diffusion_model,
load_device,
offload_device,
use_stream=is_cuda_device,
record_stream=is_cuda_device,
low_cpu_mem_usage=low_vram_state or model_too_large,
num_blocks_per_group=1
)
# then the inputs will be ready on the correct device due to the wrapper factory
model.load_device = load_device
return inner_model, conds, models
return wrapper
def prepare_layerwise_casting_factory(dtype: torch.dtype):
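    """Builds a PREPARE_SAMPLING wrapper that applies diffusers layerwise casting to the diffusion model, storing
    its weights in ``dtype`` and computing in the model's existing dtype."""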
def wrapper(executor, model: ModelPatcher, *args, **kwargs):
        # diffusers will handle the casting, so turn off ComfyUI's own weight casting on these layers first
        disable_comfyui_weight_casting(model.diffusion_model)
        apply_layerwise_casting(
            model.diffusion_model,
            dtype,  # storage dtype
            model.diffusion_model.dtype,  # compute dtype
            non_blocking=True,
        )
inner_model, conds, models = executor(model, *args, **kwargs)
return inner_model, conds, models
return wrapper
class GroupOffload(CustomNode):
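    """Applies diffusers group offloading to a model. Transformers-managed language models are hooked directly;
    other manageable models receive a PREPARE_SAMPLING wrapper so the hooks are applied when sampling is prepared."""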
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL", {})
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "execute"
def execute(self, model: ModelManageable | HooksSupport | TransformersManagedModel) -> tuple[ModelPatcher,]:
if isinstance(model, ModelManageable):
model = model.clone()
if isinstance(model, TransformersManagedModel):
apply_group_offloading(
model.model,
model.load_device,
model.offload_device,
use_stream=True,
record_stream=True,
low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,),
num_blocks_per_group=1
)
elif isinstance(model, HooksSupport) and isinstance(model, ModelManageable):
model.add_wrapper_with_key(WrappersMP.PREPARE_SAMPLING, "group_offload", prepare_group_offloading_factory(model.load_device, model.offload_device))
return model,
class LayerwiseCast(CustomNode):
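    """Stores the diffusion model's weights in the selected FP8 dtype and upcasts them per layer at compute time
    via diffusers layerwise casting, disabling ComfyUI's own weight casting on the affected layers."""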
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL", {}),
"dtype": (["float8_e4m3fn", "float8_e5m2"], {})
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "execute"
def execute(self, model: ModelPatcher, dtype: str) -> tuple[ModelPatcher,]:
model = model.clone()
if dtype == "float8_e4m3fn":
dtype = torch.float8_e4m3fn
elif dtype == "float8_e5m2":
dtype = torch.float8_e5m2
model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_layerwise_casting_factory(dtype))
return model,
export_custom_nodes()