mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 14:50:49 +08:00
115 lines
4.1 KiB
Python
115 lines
4.1 KiB
Python
import logging
|
|
|
|
import torch
|
|
from torch.nn import LayerNorm
|
|
|
|
from comfy import model_management
|
|
from comfy.model_patcher import ModelPatcher
|
|
from comfy.nodes.package_typing import CustomNode, InputTypes
|
|
|
|
# Name of the model object that TorchCompileModel patches when no
# ``object_patch`` value is supplied.
DIFFUSION_MODEL = "diffusion_model"
|
|
|
|
|
|
class TorchCompileModel:
    """Experimental node that wraps a model object with ``torch.compile``."""

    @classmethod
    def INPUT_TYPES(s):
        """Describe the inputs: a required model plus optional compile settings."""
        required_inputs = {
            "model": ("MODEL",),
        }
        optional_inputs = {
            "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
            "fullgraph": ("BOOLEAN", {"default": False}),
            "dynamic": ("BOOLEAN", {"default": False}),
            "backend": ("STRING", {"default": "inductor"}),
        }
        return {"required": required_inputs, "optional": optional_inputs}

    RETURN_TYPES = ("MODEL",)
    # Matches the name of the method defined below.
    FUNCTION = "patch"

    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
        """Return a one-element tuple holding the compiled model.

        Args:
            model: A ``ModelPatcher`` or a bare ``torch.nn.Module``. Any other
                object is passed through unchanged after a warning is logged.
            object_patch: Name of the model object to compile when ``model``
                is a ``ModelPatcher``; ``None`` falls back to
                ``DIFFUSION_MODEL``.
            fullgraph: Forwarded to ``torch.compile``.
            dynamic: Forwarded to ``torch.compile``.
            backend: Forwarded to ``torch.compile``.
        """
        target_name = DIFFUSION_MODEL if object_patch is None else object_patch
        compile_options = dict(fullgraph=fullgraph, dynamic=dynamic, backend=backend)

        if isinstance(model, ModelPatcher):
            # Patch a clone so the incoming patcher itself is not modified.
            patched = model.clone()
            compiled = torch.compile(model=patched.get_model_object(target_name), **compile_options)
            patched.add_object_patch(target_name, compiled)
            return (patched,)

        if isinstance(model, torch.nn.Module):
            return (torch.compile(model=model, **compile_options),)

        logging.warning("Encountered a model that cannot be compiled")
        return (model,)
|
|
|
|
|
|
class QuantizeModel(CustomNode):
    """Experimental node that int8-quantizes a model's ``diffusion_model``.

    Two strategies are offered: ``"torchao"`` (the default) and ``"quanto"``.
    Both underlying library calls quantize the module in place.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Describe the inputs: the model and the quantization strategy."""
        return {
            "required": {
                "model": ("MODEL", {}),
                "strategy": (["torchao", "quanto"], {"default": "torchao"}),
            }
        }

    # Matches the name of the method defined below.
    FUNCTION = "execute"
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    RETURN_TYPES = ("MODEL",)

    def execute(self, model: ModelPatcher, strategy: str = "torchao"):
        """Quantize the diffusion model and return the patched clone.

        Args:
            model: Patcher whose ``"diffusion_model"`` object is quantized.
            strategy: ``"torchao"`` or ``"quanto"``.

        Returns:
            A one-element tuple containing a clone of ``model`` with the
            quantized module registered as its ``"diffusion_model"`` patch.

        Raises:
            ValueError: If ``strategy`` is not a known strategy. This is
                checked before anything is logged, cloned or quantized.
        """
        if strategy not in ("torchao", "quanto"):
            raise ValueError(f"unknown strategy {strategy}")

        logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")

        # One clone is enough for either strategy.
        quantized = model.clone()
        unet = quantized.get_model_object("diffusion_model")
        # TODO: both libraries quantize ``unet`` in place, which is not desired.
        # NOTE: a name-based exclusion list (time_embedding., add_embedding.,
        # time_in., txt_in., vector_in., img_in., guidance_in., final_layer.)
        # was drafted here but never wired into either strategy.

        if strategy == "quanto":
            from optimum.quanto import quantize, qint8  # pylint: disable=import-error

            # This caveat is specific to Optimum, so only emit it on this path.
            logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
            # Skip LayerNorm modules that have no weight.
            exclusion_list = [
                name
                for name, module in unet.named_modules()
                if isinstance(module, LayerNorm) and module.weight is None
            ]
            quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
        else:
            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight  # pylint: disable=import-error

            quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())

        quantized.add_object_patch("diffusion_model", unet)
        return (quantized,)
|
|
|
|
|
|
# Maps each node name to the node class defined in this module.
NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,
    "QuantizeModel": QuantizeModel,
}
|