import logging
import os
from pathlib import Path
from typing import Union

import torch
import torch._inductor.codecache
import torch._inductor.utils
from torch.nn import LayerNorm

from comfy import model_management
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes

DIFFUSION_MODEL = "diffusion_model"

TORCH_COMPILE_BACKENDS = [
    "inductor",
    "torch_tensorrt",
    "onnxrt",
    "cudagraphs",
    "openxla",
    "tvm",
]

TORCH_COMPILE_MODES = [
    "default",
    "reduce-overhead",
    "max-autotune",
    "max-autotune-no-cudagraphs",
]

# Fix a torch bug on Windows: write_atomic fails when the destination file already
# exists, so delete it before delegating to the original implementation.
_old_write_atomic = torch._inductor.codecache.write_atomic


def write_atomic(path: str, content: Union[str, bytes], make_dirs: bool = False) -> None:
    if Path(path).exists():
        os.remove(path)
    _old_write_atomic(path, content, make_dirs=make_dirs)


torch._inductor.codecache.write_atomic = write_atomic


class TorchCompileModel(CustomNode):
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
            },
            "optional": {
                "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
                "fullgraph": ("BOOLEAN", {"default": False}),
                "dynamic": ("BOOLEAN", {"default": False}),
                "backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
                "mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    # INFERENCE_MODE = False
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self,
              model: ModelPatcher,
              object_patch: str | None = DIFFUSION_MODEL,
              fullgraph: bool = False,
              dynamic: bool = False,
              backend: str = "inductor",
              mode: str = "max-autotune") -> tuple[ModelPatcher]:
        if object_patch is None:
            object_patch = DIFFUSION_MODEL

        compile_kwargs = {
            "fullgraph": fullgraph,
            "dynamic": dynamic,
            "backend": backend,
            "mode": mode,
        }

        try:
            if backend == "torch_tensorrt":
                compile_kwargs["options"] = {
                    # https://pytorch.org/TensorRT/dynamo/torch_compile.html
                    # Quantization/INT8 support is slated for a future release; currently,
                    # FP16 and FP32 precision layers are supported.
                    "enabled_precisions": {torch.float, torch.half}
                }
            if isinstance(model, ModelPatcher):
                m = model.clone()
                m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
                return (m,)
            elif isinstance(model, torch.nn.Module):
                return (torch.compile(model=model, **compile_kwargs),)
            else:
                logging.warning("Encountered a model that cannot be compiled")
                return (model,)
        except OSError:
            # A stale or corrupted inductor cache can raise OSError; clear it before re-raising.
            torch._inductor.utils.clear_inductor_caches()
            raise


_QUANTIZATION_STRATEGIES = [
    "torchao",
    "torchao-autoquant",
    "quanto",
]


class QuantizeModel(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL", {}),
                "strategy": (_QUANTIZATION_STRATEGIES, {"default": _QUANTIZATION_STRATEGIES[0]}),
            }
        }

    FUNCTION = "execute"
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True
    # INFERENCE_MODE = False
    RETURN_TYPES = ("MODEL",)

    def warn_in_place(self, model: ModelPatcher):
        logging.warning(f"Quantizing {model} this way quantizes it in place, making it unsuitable for cloning. All uses of this model will be quantized.")
    def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
        model = model.clone()
        unet = model.get_model_object("diffusion_model")
        # todo: quantization happens in place, which is not desired

        # default exclusions: embedding, input and output layers stay in full precision
        always_exclude_these = {
            "time_embedding.",
            "add_embedding.",
            "time_in.in",
            "txt_in",
            "vector_in.in",
            "img_in",
            "guidance_in.in",
            "final_layer",
        }
        if strategy == "quanto":
            logging.warning(f"Quantizing {model} will produce poor results due to Optimum Quanto's limitations")
            self.warn_in_place(model)
            from optimum.quanto import quantize, qint8  # pylint: disable=import-error

            # LayerNorm modules without weights cannot be quantized by Quanto, so exclude them.
            exclusion_list = [
                name for name, module in unet.named_modules()
                if isinstance(module, LayerNorm) and module.weight is None
            ]
            quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
            _in_place_fixme = unet
        elif "torchao" in strategy:
            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass, autoquant  # pylint: disable=import-error

            self.warn_in_place(model)

            def filter_fn(module: torch.nn.Module, fqn: str) -> bool:
                # quantize only Linear layers whose qualified name is not in the exclusion set
                return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in always_exclude_these)

            if "autoquant" in strategy:
                _in_place_fixme = autoquant(unet, error_on_unseen=False)
            else:
                quantize_(unet, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn, device=model_management.get_torch_device(), set_inductor_config=False)
                _in_place_fixme = unet

            unwrap_tensor_subclass(_in_place_fixme)
        else:
            raise ValueError(f"unknown strategy {strategy}")

        model.add_object_patch("diffusion_model", _in_place_fixme)
        return (model,)


NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,
    "QuantizeModel": QuantizeModel,
}