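"""Custom nodes that integrate diffusers' offloading and casting hooks with ComfyUI model management.

``GroupOffload`` applies diffusers ``apply_group_offloading`` so model weights live on the offload device and are
moved to the load device one block at a time during sampling. ``LayerwiseCast`` stores weights in a float8 dtype via
``apply_layerwise_casting`` while computing in the model's original dtype.
"""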
import torch
import logging
from diffusers import HookRegistry
from diffusers.hooks import apply_group_offloading, apply_layerwise_casting, ModelHook

from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_management import vram_state, VRAMState, unload_all_models, get_free_memory, get_torch_device
from comfy.model_management_types import HooksSupport, ModelManageable
from comfy.model_patcher import ModelPatcher
from comfy.node_helpers import export_custom_nodes
from comfy.nodes.package_typing import CustomNode
from comfy.ops import manual_cast
from comfy.patcher_extension import WrappersMP
from comfy.rmsnorm import RMSNorm

_DISABLE_COMFYUI_CASTING_HOOK = "disable_comfyui_casting_hook"

logger = logging.getLogger(__name__)


class DisableComfyWeightCast(ModelHook):
    r"""
    A hook that disables ComfyUI's built-in weight casting (``comfy_cast_weights``) on a module so that diffusers'
    layerwise casting can manage the weight dtypes instead, and restores the flag when the hook is removed.
    """

    _is_stateful = False

    def __init__(self) -> None:
        super().__init__()

    def initialize_hook(self, module: torch.nn.Module):
        if hasattr(module, "comfy_cast_weights"):
            module.comfy_cast_weights = False
        return module

    # note: "deinitalize" matches the spelling used by diffusers' ModelHook API
    def deinitalize_hook(self, module: torch.nn.Module):
        if hasattr(module, "comfy_cast_weights"):
            module.comfy_cast_weights = True
        return module


def disable_comfyui_weight_casting_hook(module: torch.nn.Module):
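    """Registers a DisableComfyWeightCast hook on ``module`` under the key ``_DISABLE_COMFYUI_CASTING_HOOK``."""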
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = DisableComfyWeightCast()
    registry.register_hook(hook, _DISABLE_COMFYUI_CASTING_HOOK)


def disable_comfyui_weight_casting(module: torch.nn.Module):
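    """Recursively disables ComfyUI weight casting on every submodule of a supported layer type."""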
    types = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.GroupNorm,
        torch.nn.LayerNorm,
        RMSNorm,
        torch.nn.ConvTranspose2d,
        torch.nn.ConvTranspose1d,
        torch.nn.Embedding
    ]
    try:
        from torch.nn import RMSNorm as TorchRMSNorm  # pylint: disable=no-member
        types.append(TorchRMSNorm)
    except (ImportError, ModuleNotFoundError):
        pass

    if isinstance(module, tuple(types)):
        disable_comfyui_weight_casting_hook(module)
        return

    for submodule in module.children():
        disable_comfyui_weight_casting(submodule)


def prepare_group_offloading_factory(load_device: torch.device, offload_device: torch.device):
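    """
    Builds a PREPARE_SAMPLING wrapper that loads the model onto ``offload_device`` and lets diffusers' group
    offloading move it to ``load_device`` one block at a time during sampling.
    """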
    def wrapper(executor, model: ModelPatcher, *args, **kwargs):
        # this model will now just be loaded to CPU, since diffusers will manage moving it to the GPU
        model.load_device = offload_device

        # unload everything first so memory pinning can be used effectively; this includes trimming
        unload_all_models()

        # loads the model and prepares everything
        inner_model, conds, models = executor(model, *args, **kwargs)

        # layerwise casting from diffusers would be needed in this situation, which group offloading does not provide
        if model.model.operations == manual_cast and model.diffusion_model.dtype != model.model.manual_cast_dtype:
            raise ValueError("manual cast operations, where the model is stored in a different dtype than inference runs in, are not supported")

        # weights are patched and ready to go; the inner model will be correctly deleted at the end of sampling
        model_size = model.model_size()

        # only use memory pinning when roughly 2x the model size is free in host memory
        model_too_large = model_size * 2 > get_free_memory(torch.device("cpu"))
        low_vram_state = vram_state in (VRAMState.LOW_VRAM,)
        is_cuda_device = load_device.type == 'cuda'

        if model_too_large or low_vram_state:
            logger.error(f"group offloading did not use memory pinning because model_too_large={model_too_large} low_vram_state={low_vram_state}")
        if not is_cuda_device:
            logger.error(f"group offloading did not use a stream because load_device.type={load_device.type} != \"cuda\"")
        apply_group_offloading(
            inner_model.diffusion_model,
            load_device,
            offload_device,
            use_stream=is_cuda_device,
            record_stream=is_cuda_device,
            low_cpu_mem_usage=low_vram_state or model_too_large,
            num_blocks_per_group=1
        )
        # the inputs will then be ready on the correct device due to the wrapper factory
        model.load_device = load_device
        return inner_model, conds, models

    return wrapper


def prepare_layerwise_casting_factory(dtype: torch.dtype):
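    """
    Builds a PREPARE_SAMPLING wrapper that disables ComfyUI's own weight casting and applies diffusers' layerwise
    casting, storing weights in ``dtype`` and upcasting to the model's original dtype for computation.
    """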
    def wrapper(executor, model: ModelPatcher, *args, **kwargs):
        disable_comfyui_weight_casting(model.diffusion_model)
        apply_layerwise_casting(model.diffusion_model,
                                dtype,
                                model.diffusion_model.dtype,
                                non_blocking=True)
        inner_model, conds, models = executor(model, *args, **kwargs)

        return inner_model, conds, models

    return wrapper


class GroupOffload(CustomNode):
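    """
    Node that enables diffusers group offloading for a model: weights stay on the offload device and are moved to the
    load device in groups of one block during sampling.
    """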
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", {})
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "execute"

    def execute(self, model: ModelManageable | HooksSupport | TransformersManagedModel) -> tuple[ModelPatcher,]:
        if isinstance(model, ModelManageable):
            model = model.clone()
        if isinstance(model, TransformersManagedModel):
            apply_group_offloading(
                model.model,
                model.load_device,
                model.offload_device,
                use_stream=True,
                record_stream=True,
                low_cpu_mem_usage=vram_state in (VRAMState.LOW_VRAM,),
                num_blocks_per_group=1
            )
        elif isinstance(model, HooksSupport) and isinstance(model, ModelManageable):
            model.add_wrapper_with_key(WrappersMP.PREPARE_SAMPLING, "group_offload", prepare_group_offloading_factory(model.load_device, model.offload_device))
        return model,


class LayerwiseCast(CustomNode):
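    """
    Node that stores a model's weights in a float8 dtype and upcasts them layer by layer at inference time using
    diffusers layerwise casting.
    """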
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL", {}),
                "dtype": (["float8_e4m3fn", "float8_e5m2"], {})
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "execute"

    def execute(self, model: ModelPatcher, dtype: str) -> tuple[ModelPatcher,]:
        model = model.clone()
        if dtype == "float8_e4m3fn":
            dtype = torch.float8_e4m3fn
        elif dtype == "float8_e5m2":
            dtype = torch.float8_e5m2

        model.add_wrapper(WrappersMP.PREPARE_SAMPLING, prepare_layerwise_casting_factory(dtype))
        return model,


export_custom_nodes()