ComfyUI/comfy_extras/nodes/nodes_torch_compile.py
doctorpangloss bbe2ed330c Memory management and compilation improvements
- Experimental support for sage attention on Linux
- Diffusers loader now supports model indices
- Transformers model management now aligns with updates to ComfyUI
- Flux layers correctly use unbind
- Add float8 support for model loading in more places
- Experimental quantization approaches from Quanto and torchao
- Model upscaling interacts with memory management better

This update also disables ROCm testing because it is not reliable enough
on consumer hardware; the 7600 is not really supported by ROCm.
2024-10-09 09:13:47 -07:00

import logging

import torch
from torch.nn import LayerNorm

from comfy import model_management
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes

DIFFUSION_MODEL = "diffusion_model"


class TorchCompileModel:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
            },
            "optional": {
                "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
                "fullgraph": ("BOOLEAN", {"default": False}),
                "dynamic": ("BOOLEAN", {"default": False}),
                "backend": ("STRING", {"default": "inductor"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True
    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
        if object_patch is None:
            object_patch = DIFFUSION_MODEL
        compile_kwargs = {
            "fullgraph": fullgraph,
            "dynamic": dynamic,
            "backend": backend,
        }
        if isinstance(model, ModelPatcher):
            # clone first so the compiled module is attached to the copy as an
            # object patch, leaving the original ModelPatcher usable elsewhere
            m = model.clone()
            m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
            return (m,)
        elif isinstance(model, torch.nn.Module):
            return torch.compile(model=model, **compile_kwargs),
        else:
            logging.warning("Encountered a model that cannot be compiled")
            return model,
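
# Usage sketch (illustrative, not part of the original file): wiring a loaded
# checkpoint's MODEL output through this node compiles only the cloned copy,
# so the source ModelPatcher remains uncompiled. `loaded_model` here is a
# hypothetical ModelPatcher produced by an upstream loader node.
#
#     compiled_model, = TorchCompileModel().patch(loaded_model, object_patch=DIFFUSION_MODEL, backend="inductor")
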
class QuantizeModel(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL", {}),
                "strategy": (["torchao", "quanto"], {"default": "torchao"}),
            }
        }

    FUNCTION = "execute"
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True
    RETURN_TYPES = ("MODEL",)
    def execute(self, model: ModelPatcher, strategy: str = "torchao"):
        logging.warning(f"Quantizing {model} this way quantizes it in place, making it unsuitable for cloning. All uses of this model will be quantized.")
        logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
        model = model.clone()
        unet = model.get_model_object("diffusion_model")
        # todo: quantize quantizes in place, which is not desired
        # default exclusions
        _unused_exclusions = {
            "time_embedding.",
            "add_embedding.",
            "time_in.",
            "txt_in.",
            "vector_in.",
            "img_in.",
            "guidance_in.",
            "final_layer.",
        }
        if strategy == "quanto":
            from optimum.quanto import quantize, qint8
            # exclude weightless LayerNorm modules, which quanto cannot quantize
            exclusion_list = [
                name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None
            ]
            quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
            _in_place_fixme = unet
        elif strategy == "torchao":
            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
            # todo: quantize_ quantizes in place, which is not desired
            # def filter_fn(module: torch.nn.Module, name: str):
            #     return any("weight" in name for name, _ in (module.named_parameters())) and all(exclusion not in name for exclusion in exclusions)
            quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
            _in_place_fixme = unet
        else:
            raise ValueError(f"unknown strategy {strategy}")
        model.add_object_patch("diffusion_model", _in_place_fixme)
        return model,
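
# One way to resolve the in-place TODOs above (a sketch under the assumption
# that the weights fit in memory twice; untested here): deep-copy the module
# before quantizing, so the ModelPatcher's original weights stay untouched and
# cloning remains safe.
#
#     import copy
#     unet_copy = copy.deepcopy(unet)
#     quantize_(unet_copy, int8_dynamic_activation_int8_weight())
#     model.add_object_patch("diffusion_model", unet_copy)
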
NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,
    "QuantizeModel": QuantizeModel,
}
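
# Minimal self-contained smoke test (an illustrative addition, not part of the
# original node file). It exercises the torch.nn.Module branch of
# TorchCompileModel using the dependency-free "eager" backend.
if __name__ == "__main__":
    _net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    _compiled, = TorchCompileModel().patch(_net, backend="eager")
    print(_compiled(torch.randn(1, 8)).shape)  # expected: torch.Size([1, 8])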