import logging
import os
from pathlib import Path
from typing import Union, Callable

import torch
import torch._inductor.codecache
from torch.nn import LayerNorm

from comfy import model_management
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_management_types import HooksSupport
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes
from comfy.sd import VAE
from comfy_api.torch_helpers import set_torch_compile_wrapper

logger = logging.getLogger(__name__)

DIFFUSION_MODEL = "diffusion_model"

TORCH_COMPILE_BACKENDS = [
    "inductor",
    "torch_tensorrt",
    "onnxrt",
    "cudagraphs",
    "openxla",
    "tvm",
]

TORCH_COMPILE_MODES = [
    "default",
    "reduce-overhead",
    "max-autotune",
    "max-autotune-no-cudagraphs",
]

# fix torch bug on windows: atomic cache writes fail when the destination file already
# exists, so delete it before delegating to the original implementation
_old_write_atomic = torch._inductor.codecache.write_atomic


def write_atomic(path: str, content: Union[str, bytes], make_dirs: bool = False) -> None:
    if Path(path).exists():
        os.remove(path)
    _old_write_atomic(path, content, make_dirs=make_dirs)


torch._inductor.codecache.write_atomic = write_atomic


# torch._inductor.utils.is_big_gpu = lambda *args: True

def skip_torch_compile_dict(guard_entries):
    # guard filter for torch.compile: drop guards that reference transformer_options so
    # per-call changes to that dict do not trigger recompilation
    return [("transformer_options" not in entry.name) for entry in guard_entries]


class TorchCompileModel(CustomNode):
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL,VAE",),
            },
            "optional": {
                "object_patch": ("STRING", {"default": ""}),
                "fullgraph": ("BOOLEAN", {"default": False}),
                "dynamic": ("BOOLEAN", {"default": False}),
                "backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
                "mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"}),
                "torch_tensorrt_optimization_level": ("INT", {"default": 3, "min": 1, "max": 5}),
            }
        }

    RETURN_TYPES = ("MODEL,VAE",)
    FUNCTION = "patch"
    RETURN_NAMES = ("model or vae",)
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self,
              model: ModelPatcher | VAE | torch.nn.Module,
              object_patch: str | None = "",
              fullgraph: bool = False,
              dynamic: bool = False,
              backend: str = "inductor",
              mode: str = "max-autotune",
              torch_tensorrt_optimization_level: int = 3) -> tuple[Callable]:
        compile_kwargs = {
            "fullgraph": fullgraph,
            "dynamic": dynamic,
            "backend": backend,
            "mode": mode,
        }
        try:
            if backend == "torch_tensorrt":
                try:
                    import torch_tensorrt
                except (ImportError, ModuleNotFoundError):
                    logger.error("Install torch-tensorrt and modelopt")
                    raise
                compile_kwargs["options"] = {
                    # https://pytorch.org/TensorRT/dynamo/torch_compile.html
                    # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
                    "enabled_precisions": {torch.float, torch.half, torch.float8_e5m2},
                    "optimization_level": torch_tensorrt_optimization_level,
                    "cache_built_engines": True,
                    "reuse_cached_engines": True,
                    "enable_weight_streaming": True,
                    "make_refittable": True,
                }
                # torch.compile does not allow passing both mode and options
                del compile_kwargs["mode"]
            if isinstance(model, HooksSupport):
                to_return = model.clone()
                object_patches = [p.strip() for p in object_patch.split(",")] if object_patch is not None else []
                patcher: ModelPatcher
                if isinstance(to_return, VAE):
                    patcher = to_return.patcher
                    object_patches = ["encoder", "decoder"]
                else:
                    patcher = to_return
                if len(object_patches) == 0 or (len(object_patches) == 1 and object_patches[0] == ""):
                    object_patches = [DIFFUSION_MODEL]
                set_torch_compile_wrapper(patcher, keys=object_patches, options={"guard_filter_fn": skip_torch_compile_dict}, **compile_kwargs)
                return to_return,
            elif isinstance(model, torch.nn.Module):
                model_management.unload_all_models()
                model.to(device=model_management.get_torch_device())
                res = torch.compile(model=model, **compile_kwargs)
                model.to(device=model_management.unet_offload_device())
                return res,
            else:
                logger.warning(f"Encountered a model {model} that cannot be compiled")
                return model,
        except OSError as os_error:
            try:
                torch._inductor.utils.clear_inductor_caches()  # pylint: disable=no-member
            except Exception:
                pass
            raise os_error
        except Exception as exc_info:
            try:
                torch._inductor.utils.clear_inductor_caches()  # pylint: disable=no-member
            except Exception:
                pass
            logger.error(f"An exception occurred while trying to compile {str(model)}, gracefully skipping compilation", exc_info=exc_info)
            return model,


_QUANTIZATION_STRATEGIES = [
    "torchao",
    "torchao-autoquant",
    "quanto",
]


class QuantizeModel(CustomNode):
    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL", {}),
                "strategy": (_QUANTIZATION_STRATEGIES, {"default": _QUANTIZATION_STRATEGIES[0]}),
            }
        }

    FUNCTION = "execute"
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True
    RETURN_TYPES = ("MODEL",)

    def warn_in_place(self, model: ModelPatcher):
        logger.warning(f"Quantizing {model} this way quantizes it in place, making it unsuitable for cloning. "
                       "All uses of this model will be quantized.")
    def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
        model = model.clone()
        model.patch_model(force_patch_weights=True)
        unet = model.diffusion_model
        # todo: quantize quantizes in place, which is not desired
        # default exclusions, matched by prefix against module FQNs in the torchao filter below
        always_exclude_these = {
            "time_embedding.",
            "add_embedding.",
            "time_in.in",
            "txt_in",
            "vector_in.in",
            "img_in",
            "guidance_in.in",
            "final_layer",
        }
        if strategy == "quanto":
            logger.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
            self.warn_in_place(model)
            from optimum.quanto import quantize, qint8  # pylint: disable=import-error

            # exclude weightless LayerNorm modules from quantization
            exclusion_list = [
                name for name, module in unet.named_modules()
                if isinstance(module, LayerNorm) and module.weight is None
            ]
            quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
            _in_place_fixme = unet
        elif "torchao" in strategy:
            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, autoquant  # pylint: disable=import-error
            from torchao.utils import unwrap_tensor_subclass  # pylint: disable=import-error

            self.warn_in_place(model)
            model_management.load_models_gpu([model])

            def filter_fn(module: torch.nn.Module, fqn: str) -> bool:
                # quantize only Linear layers whose names do not match the exclusion prefixes
                return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in always_exclude_these)

            if "autoquant" in strategy:
                _in_place_fixme = autoquant(unet, error_on_unseen=False)
            else:
                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter_fn)
                _in_place_fixme = unet
            unwrap_tensor_subclass(_in_place_fixme)
        else:
            raise ValueError(f"unknown strategy {strategy}")
        model.add_object_patch("diffusion_model", _in_place_fixme)
        return model,


NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,
    "QuantizeModel": QuantizeModel,
}
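
# A minimal, hypothetical usage sketch: these nodes are normally executed by the ComfyUI graph,
# but calling them directly would look roughly like the following, assuming `model_patcher` is a
# ModelPatcher that was loaded elsewhere (e.g. by a checkpoint loader node):
#
#     compiled_patcher, = TorchCompileModel().patch(model_patcher, backend="inductor", mode="default")
#     quantized_patcher, = QuantizeModel().execute(compiled_patcher, strategy="torchao")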