mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
Group offloading support
This commit is contained in:
parent 7179ed44bd
commit 112017f2d2
150 comfy_extras/nodes/nodes_group_offloading.py Normal file
@@ -0,0 +1,150 @@
import torch

from diffusers import HookRegistry
from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook

from comfy.model_management import vram_state, VRAMState
from comfy.model_patcher import ModelPatcher
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode
from comfy.ops import manual_cast
from comfy.patcher_extension import WrappersMP
from comfy.rmsnorm import RMSNorm

_DISABLE_COMFYUI_CASTING_HOOK = "disable_comfyui_casting_hook"


class DisableComfyWeightCast(ModelHook):
    r"""
    A hook that disables ComfyUI's own on-the-fly weight casting (``comfy_cast_weights``)
    on a module while attached, so that the diffusers hooks applied alongside it manage
    the module's weights instead.
    """

    _is_stateful = False

    def __init__(self) -> None:
        super().__init__()

    def initialize_hook(self, module: torch.nn.Module):
        if hasattr(module, "comfy_cast_weights"):
            module.comfy_cast_weights = False
        return module

    def deinitalize_hook(self, module: torch.nn.Module):
        # note: spelling follows the diffusers ModelHook base class
        if hasattr(module, "comfy_cast_weights"):
            module.comfy_cast_weights = True
        return module

def disable_comfyui_weight_casting_hook(module: torch.nn.Module):
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = DisableComfyWeightCast()
    registry.register_hook(hook, _DISABLE_COMFYUI_CASTING_HOOK)


def disable_comfyui_weight_casting(module: torch.nn.Module):
    # attach the hook to the leaf layer types whose weights ComfyUI would otherwise cast
    if isinstance(module, (
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.GroupNorm,
        torch.nn.LayerNorm,
        torch.nn.RMSNorm,
        RMSNorm,
        torch.nn.ConvTranspose2d,
        torch.nn.ConvTranspose1d,
        torch.nn.Embedding,
    )):
        disable_comfyui_weight_casting_hook(module)
        return

    for submodule in module.children():
        disable_comfyui_weight_casting(submodule)

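# Illustrative aside (not part of the commit): a minimal, hedged sketch of the
# diffusers hook-registry pattern used above. `PrintShapeHook` and the module
# below are hypothetical; only HookRegistry/ModelHook are real diffusers classes.
def _hook_registry_example():
    class PrintShapeHook(ModelHook):
        def pre_forward(self, module, *args, **kwargs):
            # runs just before the wrapped module's forward
            print(f"{module.__class__.__name__} input shape: {args[0].shape}")
            return args, kwargs

    linear = torch.nn.Linear(8, 8)
    registry = HookRegistry.check_if_exists_or_initialize(linear)
    registry.register_hook(PrintShapeHook(), "print_shape")
    linear(torch.randn(1, 8))            # pre_forward fires before forward
    registry.remove_hook("print_shape")  # detaches the hook, restoring the module
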
def prepare_group_offloading_factory(load_device: torch.device, offload_device: torch.device):
    def wrapper(executor, model: ModelPatcher, *args, **kwargs):
        # load the model onto the offload device (CPU); diffusers will manage
        # moving each group of blocks to the GPU during sampling
        model.load_device = offload_device
        # loads the model and prepares everything
        inner_model, conds, models = executor(model, *args, **kwargs)

        # group offloading cannot cast weights on the fly; layerwise casting
        # from diffusers is needed in this situation instead
        if model.model.operations == manual_cast and model.diffusion_model.dtype != model.model.manual_cast_dtype:
            raise ValueError("manual cast operations, where the model is stored in a different dtype than the one used for inference, are not supported")

        # weights are patched and ready to go; the inner model will be correctly
        # deleted at the end of sampling
        apply_group_offloading(
            inner_model.diffusion_model,
            load_device,
            offload_device,
            use_stream=True,
            record_stream=True,
            low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,),
            num_blocks_per_group=1
        )
        # restore the load device so the inputs are prepared on the correct
        # device by the wrapper factory
        model.load_device = load_device
        return inner_model, conds, models

    return wrapper

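# Illustrative aside (not part of the commit): apply_group_offloading on a bare
# module, showing what the wrapper above sets up for the diffusion model.
# `_SketchNet` is hypothetical and a CUDA device is assumed.
def _group_offloading_example():
    class _SketchNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(8)])

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    net = _SketchNet()
    apply_group_offloading(
        net,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        num_blocks_per_group=1,  # stream one block of the ModuleList at a time
        use_stream=True,         # overlap host-device copies with compute
    )
    # weights stay on the CPU; each block is moved to the GPU just before it runs
    return net(torch.randn(1, 64, device="cuda"))
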
def prepare_layerwise_casting_factory(dtype: torch.dtype):
    def wrapper(executor, model: ModelPatcher, *args, **kwargs):
        # stop ComfyUI from casting weights itself; the diffusers hooks below
        # upcast the low-precision storage dtype per layer instead
        disable_comfyui_weight_casting(model.diffusion_model)
        apply_layerwise_casting(model.diffusion_model,
                                storage_dtype=dtype,
                                compute_dtype=model.diffusion_model.dtype,
                                non_blocking=True)
        inner_model, conds, models = executor(model, *args, **kwargs)

        return inner_model, conds, models

    return wrapper

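# Illustrative aside (not part of the commit): apply_layerwise_casting on a bare
# module, showing the storage/compute dtype split that the factory above applies
# to the diffusion model.
def _layerwise_casting_example():
    linear = torch.nn.Linear(64, 64).to(torch.bfloat16)
    apply_layerwise_casting(
        linear,
        storage_dtype=torch.float8_e4m3fn,  # weights kept in fp8 between calls
        compute_dtype=torch.bfloat16,       # weights upcast per layer for forward
        non_blocking=True,
    )
    return linear(torch.randn(1, 64, dtype=torch.bfloat16))
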
class GroupOffload(CustomNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", {})
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "execute"

    def execute(self, model: ModelPatcher) -> tuple[ModelPatcher,]:
        model = model.clone()
        model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_group_offloading_factory(model.load_device, model.offload_device))
        return model,

class LayerwiseCast(CustomNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", {}),
                "dtype": (["float8_e4m3fn", "float8_e5m2"], {})
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "execute"

    def execute(self, model: ModelPatcher, dtype: str) -> tuple[ModelPatcher,]:
        model = model.clone()
        if dtype == "float8_e4m3fn":
            dtype = torch.float8_e4m3fn
        elif dtype == "float8_e5m2":
            dtype = torch.float8_e5m2

        model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_layerwise_casting_factory(dtype))
        return model,


export_custom_nodes()
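A minimal usage sketch (not part of the commit): given a ModelPatcher from any loader node, both nodes clone it and attach their PREPARE_SAMPLING wrapper, so downstream sampling runs unchanged:

    offloaded_model, = GroupOffload().execute(model)
    fp8_model, = LayerwiseCast().execute(model, "float8_e4m3fn")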