diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py
index 1e28383cf..b9c05b95c 100644
--- a/comfy/cmd/execution.py
+++ b/comfy/cmd/execution.py
@@ -9,6 +9,7 @@ import threading
 import time
 import traceback
 import typing
+from contextlib import nullcontext
 from os import PathLike
 from typing import List, Optional, Tuple
 
@@ -28,7 +29,7 @@ from ..component_model.files import canonicalize_path
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, ExecutionContext
 from ..nodes.package import import_all_nodes_in_workspace
-from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
+from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
 
 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
@@ -544,7 +545,7 @@ class PromptExecutor:
         self.status_messages = []
         self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
 
-        with torch.inference_mode():
+        with torch.inference_mode() if all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) else nullcontext():
             dynamic_prompt = DynamicPrompt(prompt)
             is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
             for cache in self.caches.all:
@@ -558,7 +559,7 @@ class PromptExecutor:
 
             model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
-                             { "nodes": cached_nodes, "prompt_id": prompt_id},
+                             {"nodes": cached_nodes, "prompt_id": prompt_id},
                              broadcast=False)
             pending_subgraph_results = {}
             executed = set()
@@ -584,7 +585,7 @@ class PromptExecutor:
                 execution_list.complete_node_execution()
             else:
                 # Only execute when the while-loop ends without break
-                self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+                self.add_message("execution_success", {"prompt_id": prompt_id}, broadcast=False)
 
             ui_outputs = {}
             meta_outputs = {}
@@ -607,6 +608,11 @@ class PromptExecutor:
         return self.history_result["outputs"] if self.history_result is not None else None
 
 
+def iterate_obj_classes(prompt: dict[str, typing.Any]) -> typing.Generator[typing.Type[CustomNode], None, None]:
+    for _, node in prompt.items():
+        yield nodes.NODE_CLASS_MAPPINGS[node['class_type']]
+
+
 def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
     # todo: this should check if LoadImage / LoadImageMask paths exist
     # todo: or, nodes should provide a way to validate their values
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 213dd3eb4..57a76bc43 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -248,6 +248,14 @@ class ModelPatcher(ModelManageable):
         else:
             return utils.get_attr(self.model, name)
 
+    @property
+    def diffusion_model(self) -> BaseModel:
+        return self.get_model_object("diffusion_model")
+
+    @diffusion_model.setter
+    def diffusion_model(self, value: torch.nn.Module):
+        self.add_object_patch("diffusion_model", value)
+
     def model_patches_to(self, device):
         to = self.model_options["transformer_options"]
         if "patches" in to:
diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index bebf1e3a1..0a5118e36 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -21,6 +21,7 @@ from .. import sd
 from .. import utils
 from .. import clip_vision as clip_vision_module
 from .. import model_management
+from .. import ops
 from ..cli_args import args
 from ..cmd import folder_paths, latent_preview
 
@@ -898,9 +899,8 @@ class UNETLoader:
 
     CATEGORY = "advanced/loaders"
 
-    def load_unet(self, unet_name, weight_dtype):
+    def load_unet(self, unet_name, weight_dtype="default"):
         model_options = get_model_options_for_dtype(weight_dtype)
-
         unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
         model = sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)
diff --git a/comfy/nodes/package_typing.py b/comfy/nodes/package_typing.py
index 005b76132..8bf97efce 100644
--- a/comfy/nodes/package_typing.py
+++ b/comfy/nodes/package_typing.py
@@ -127,6 +127,7 @@ class CustomNode(Protocol):
     FUNCTION: ClassVar[str]
     CATEGORY: ClassVar[str]
     OUTPUT_NODE: Optional[ClassVar[bool]]
+    INFERENCE_MODE: Optional[ClassVar[bool]]
 
     @classmethod
     def IS_CHANGED(cls, *args, **kwargs) -> str:
diff --git a/comfy/ops.py b/comfy/ops.py
index aab0e82d3..5f33758f5 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -60,12 +60,51 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         weight = s.weight_function(weight)
     return weight, bias
 
+class SkipInit:
+    def reset_parameters(self):
+        return None
 
 class CastWeightBiasOp:
     comfy_cast_weights = False
     weight_function = None
     bias_function = None
 
+class skip_init:
+    class Linear(SkipInit, torch.nn.Linear):
+        pass
+
+    class Conv1d(SkipInit, torch.nn.Conv1d):
+        pass
+
+    class Conv2d(SkipInit, torch.nn.Conv2d):
+        pass
+
+    class Conv3d(SkipInit, torch.nn.Conv3d):
+        pass
+
+    class GroupNorm(SkipInit, torch.nn.GroupNorm):
+        pass
+
+    class LayerNorm(SkipInit, torch.nn.LayerNorm):
+        pass
+
+    class ConvTranspose2d(SkipInit, torch.nn.ConvTranspose2d):
+        pass
+
+    class ConvTranspose1d(SkipInit, torch.nn.ConvTranspose1d):
+        pass
+
+    class Embedding(SkipInit, torch.nn.Embedding):
+        pass
+
+    @classmethod
+    def conv_nd(cls, dims, *args, **kwargs):
+        if dims == 2:
+            return cls.Conv2d(*args, **kwargs)
+        elif dims == 3:
+            return cls.Conv3d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
 
 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
diff --git a/comfy_extras/nodes/nodes_torch_compile.py b/comfy_extras/nodes/nodes_torch_compile.py
index 77f1f1294..33385d5eb 100644
--- a/comfy_extras/nodes/nodes_torch_compile.py
+++ b/comfy_extras/nodes/nodes_torch_compile.py
@@ -10,7 +10,7 @@ from comfy.nodes.package_typing import CustomNode, InputTypes
 DIFFUSION_MODEL = "diffusion_model"
 
 
-class TorchCompileModel:
+class TorchCompileModel(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -27,11 +27,12 @@ class TorchCompileModel:
 
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
+    INFERENCE_MODE = False
 
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
 
-    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
+    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor") -> tuple[ModelPatcher]:
         if object_patch is None:
             object_patch = DIFFUSION_MODEL
         compile_kwargs = {
@@ -39,6 +40,12 @@ class TorchCompileModel:
             "dynamic": dynamic,
             "backend": backend
         }
+        if backend == "torch_tensorrt":
+            compile_kwargs["options"] = {
+                # https://pytorch.org/TensorRT/dynamo/torch_compile.html
+                # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
+                "enabled_precisions": {torch.float, torch.half}
+            }
         if isinstance(model, ModelPatcher):
             m = model.clone()
             m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
@@ -50,57 +57,72 @@ class TorchCompileModel:
         return model,
 
 
+_QUANTIZATION_STRATEGIES = [
+    "torchao",
+    "quanto",
+    "torchao-autoquant"
+]
+
+
 class QuantizeModel(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
                 "model": ("MODEL", {}),
-                "strategy": (["torchao", "quanto"], {"default": "torchao"})
+                "strategy": (_QUANTIZATION_STRATEGIES, {"default": _QUANTIZATION_STRATEGIES[0]})
             }
         }
 
     FUNCTION = "execute"
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
+    INFERENCE_MODE = False
     RETURN_TYPES = ("MODEL",)
 
-    def execute(self, model: ModelPatcher, strategy: str = "torchao"):
+    def warn_in_place(self, model: ModelPatcher):
         logging.warning(f"Quantizing {model} this way quantizes it in place, making it unsuitable for cloning. All uses of this model will be quantized.")
-        logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
+
+    def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
         model = model.clone()
         unet = model.get_model_object("diffusion_model")
         # todo: quantize quantizes in place, which is not desired
         # default exclusions
-        _unused_exclusions = {
+        always_exclude_these = {
             "time_embedding.",
             "add_embedding.",
-            "time_in.",
-            "txt_in.",
-            "vector_in.",
-            "img_in.",
-            "guidance_in.",
-            "final_layer.",
+            "time_in.in",
+            "txt_in",
+            "vector_in.in",
+            "img_in",
+            "guidance_in.in",
+            "final_layer",
         }
         if strategy == "quanto":
+            logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
+            self.warn_in_place(model)
             from optimum.quanto import quantize, qint8  # pylint: disable=import-error
             exclusion_list = [
                 name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None
             ]
             quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
             _in_place_fixme = unet
-        elif strategy == "torchao":
-            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight  # pylint: disable=import-error
+        elif "torchao" in strategy:
+            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, autoquant  # pylint: disable=import-error
             model = model.clone()
+            self.warn_in_place(model)
             unet = model.get_model_object("diffusion_model")
-            # todo: quantize quantizes in place, which is not desired
-            # def filter_fn(module: torch.nn.Module, name: str):
-            #     return any("weight" in name for name, _ in (module.named_parameters())) and all(exclusion not in name for exclusion in exclusions)
-            quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
-            _in_place_fixme = unet
+            def filter(module: torch.nn.Module, fqn: str) -> bool:
+                return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in always_exclude_these)
+
+            if "autoquant" in strategy:
+                _in_place_fixme = autoquant(unet, error_on_unseen=False)
+            else:
+                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter)
+                _in_place_fixme = unet
         else:
             raise ValueError(f"unknown strategy {strategy}")
diff --git a/tests/quantization/__init__.py b/tests/quantization/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/quantization/test_torchao.py b/tests/quantization/test_torchao.py
new file mode 100644
index 000000000..7c54cdbaa
--- /dev/null
+++ b/tests/quantization/test_torchao.py
@@ -0,0 +1,64 @@
+import pytest
+import torch
+
+from comfy import model_management
+from comfy.model_base import Flux
+from comfy.model_patcher import ModelPatcher
+from comfy.nodes.base_nodes import UNETLoader
+from comfy_extras.nodes.nodes_torch_compile import QuantizeModel
+
+has_torchao = True
+try:
+    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
+except (ImportError, ModuleNotFoundError):
+    has_torchao = False
+
+has_tensorrt = True
+try:
+    from comfyui_tensorrt import STATIC_TRT_MODEL_CONVERSION
+except (ImportError, ModuleNotFoundError):
+    has_tensorrt = False
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
+async def test_unit_torchao(checkpoint_name):
+    # Downloads FLUX.1-dev and loads it using ComfyUI's models
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+
+    transformer: Flux = model.get_model_object("diffusion_model")
+    quantize_(transformer, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
+    assert transformer is not None
+    del transformer
+    model_management.unload_all_models()
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
+@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
+async def test_torchao_node(checkpoint_name, strategy):
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+
+    quantized_model, = QuantizeModel().execute(model, strategy=strategy)
+
+    transformer = quantized_model.get_model_object("diffusion_model")
+    del transformer
+    model_management.unload_all_models()
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
+@pytest.mark.skipif(True, reason="not yet supported")
+async def test_torchao_into_tensorrt(checkpoint_name, strategy):
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+    model_management.load_models_gpu([model], force_full_load=True)
+    model.diffusion_model = model.diffusion_model.to(memory_format=torch.channels_last)
+    model.diffusion_model = torch.compile(model.diffusion_model, mode="max-autotune", fullgraph=True)
+
+    quantized_model, = QuantizeModel().execute(model, strategy=strategy)
+
+    STATIC_TRT_MODEL_CONVERSION().convert(quantized_model, "test", 1, 1024, 1024, 1, 14)
+    model_management.unload_all_models()