diff --git a/comfy_extras/nodes/nodes_group_offloading.py b/comfy_extras/nodes/nodes_group_offloading.py
new file mode 100644
index 000000000..f60101268
--- /dev/null
+++ b/comfy_extras/nodes/nodes_group_offloading.py
@@ -0,0 +1,150 @@
+import torch
+from diffusers import HookRegistry
+from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook
+
+from comfy.model_management import vram_state, VRAMState
+from comfy.model_patcher import ModelPatcher
+from comfy.node_helpers import export_custom_nodes
+from comfy.nodes.package_typing import CustomNode
+from comfy.ops import manual_cast
+from comfy.patcher_extension import WrappersMP
+from comfy.rmsnorm import RMSNorm
+
+_DISABLE_COMFYUI_CASTING_HOOK = "disable_comfyui_casting_hook"
+
+
+class DisableComfyWeightCast(ModelHook):
+    r"""
+    A hook that disables ComfyUI's built-in weight casting (``comfy_cast_weights``) on a module while the hook is
+    attached, so that diffusers' layerwise casting manages dtype conversion instead of ComfyUI casting the same
+    weights a second time.
+    """
+
+    _is_stateful = False
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def initialize_hook(self, module: torch.nn.Module):
+        if hasattr(module, "comfy_cast_weights"):
+            module.comfy_cast_weights = False
+        return module
+
+    def deinitalize_hook(self, module: torch.nn.Module):
+        if hasattr(module, "comfy_cast_weights"):
+            module.comfy_cast_weights = True
+        return module
+
+
+def disable_comfyui_weight_casting_hook(module: torch.nn.Module):
+    registry = HookRegistry.check_if_exists_or_initialize(module)
+    hook = DisableComfyWeightCast()
+    registry.register_hook(hook, _DISABLE_COMFYUI_CASTING_HOOK)
+
+
+def disable_comfyui_weight_casting(module: torch.nn.Module):
+    if isinstance(module, (
+        torch.nn.Linear,
+        torch.nn.Conv1d,
+        torch.nn.Conv2d,
+        torch.nn.Conv3d,
+        torch.nn.GroupNorm,
+        torch.nn.LayerNorm,
+        torch.nn.RMSNorm,
+        RMSNorm,
+        torch.nn.ConvTranspose2d,
+        torch.nn.ConvTranspose1d,
+        torch.nn.Embedding
+    )):
+        disable_comfyui_weight_casting_hook(module)
+        return
+
+    for _, submodule in module.named_children():
+        disable_comfyui_weight_casting(submodule)
+
+
+def prepare_group_offloading_factory(load_device: torch.device, offload_device: torch.device):
+    def wrapper(executor, model: ModelPatcher, *args, **kwargs):
+        # load the model to the offload device (CPU); diffusers will manage moving groups onto the GPU
+        model.load_device = offload_device
+        # load the model and prepare everything for sampling
+        inner_model, conds, models = executor(model, *args, **kwargs)
+
+        # group offloading cannot cooperate with ComfyUI's manual casting when storage and compute dtypes differ
+        if model.model.operations == manual_cast and model.diffusion_model.dtype != model.model.manual_cast_dtype:
+            raise ValueError("manual cast operations, where the model is stored in a different dtype than the one used for inference, are not supported")
+
+        # weights are patched and ready; the inner model will be cleaned up correctly at the end of sampling
+        apply_group_offloading(
+            inner_model.diffusion_model,
+            load_device,
+            offload_device,
+            use_stream=True,
+            record_stream=True,
+            low_cpu_mem_usage=vram_state == VRAMState.LOW_VRAM,
+            num_blocks_per_group=1
+        )
+        # restore load_device so the wrapper factory places the inputs on the correct device
+        model.load_device = load_device
+        return inner_model, conds, models
+
+    return wrapper
+
+
+def prepare_layerwise_casting_factory(dtype: torch.dtype):
+    def wrapper(executor, model: ModelPatcher, *args, **kwargs):
+        disable_comfyui_weight_casting(model.diffusion_model)  # let diffusers own the dtype conversion
+        apply_layerwise_casting(model.diffusion_model,
+                                storage_dtype=dtype,
+                                compute_dtype=model.diffusion_model.dtype,
+                                non_blocking=True)
+        inner_model, conds, models = executor(model, *args, **kwargs)
+
+        return inner_model, conds, models
+
+    return wrapper
+
+
+class GroupOffload(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "model": ("MODEL", {})
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "execute"
+
+    def execute(self, model: ModelPatcher) -> tuple[ModelPatcher,]:
+        model = model.clone()
+        model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
+        return model,
+
+
+class LayerwiseCast(CustomNode):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "model": ("MODEL", {}),
+                "dtype": (["float8_e4m3fn", "float8_e5m2"], {})
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "execute"
+
+    def execute(self, model: ModelPatcher, dtype: str) -> tuple[ModelPatcher,]:
+        model = model.clone()
+        if dtype == "float8_e4m3fn":
+            dtype = torch.float8_e4m3fn
+        elif dtype == "float8_e5m2":
+            dtype = torch.float8_e5m2
+
+        model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_layerwise_casting_factory(dtype))
+        return model,
+
+
+export_custom_nodes()
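For reviewers unfamiliar with the diffusers hooks these two nodes wrap, below is a minimal sketch of the same calls applied to a bare `torch.nn.Module`, outside of ComfyUI's `ModelPatcher` and wrapper machinery. The toy `Block` and `ToyModel` classes are illustrative assumptions, not part of either codebase; the `apply_group_offloading` and `apply_layerwise_casting` calls follow diffusers' public API, and a recent PyTorch with float8 support is assumed. Both techniques are applied to one model here only to keep the sketch short; the PR exposes them as two separate nodes.

```python
import torch
from diffusers.hooks import apply_group_offloading, apply_layerwise_casting


class Block(torch.nn.Module):
    """One offloadable group, standing in for a transformer block."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return torch.relu(self.proj(x))


class ToyModel(torch.nn.Module):
    """Illustrative stand-in for a diffusion model's `diffusion_model`."""

    def __init__(self):
        super().__init__()
        # block_level offloading groups the children of ModuleList/Sequential
        self.blocks = torch.nn.ModuleList(Block() for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = ToyModel().to(torch.bfloat16)

# LayerwiseCast path: store weights in float8, upcast each layer to
# bfloat16 just in time for compute (norm-like layers are skipped by
# diffusers' default "auto" skip pattern).
apply_layerwise_casting(
    model,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    non_blocking=True,
)

# GroupOffload path: park weights on the CPU and stream one block at a
# time onto the GPU during the forward pass. use_stream requires CUDA.
if torch.cuda.is_available():
    apply_group_offloading(
        model,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=1,
        use_stream=True,
        record_stream=True,
    )
    # inputs go to the onload device, as the wrapper factory arranges above
    y = model(torch.randn(1, 64, device="cuda", dtype=torch.bfloat16))
```

`num_blocks_per_group=1` mirrors the wrapper above and gives the smallest VRAM footprint at the cost of more transfers; the node additionally enables diffusers' `low_cpu_mem_usage` path when ComfyUI reports a low-VRAM state, trading transfer speed for lower host-memory pressure.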