Experimental quantization support. Only Linux is meaningfully supported.
This commit is contained in: parent 5f26b76f59, commit 69e523b89d
@@ -9,6 +9,7 @@ import threading
 import time
 import traceback
 import typing
+from contextlib import nullcontext
 from os import PathLike
 from typing import List, Optional, Tuple
 
@@ -28,7 +29,7 @@ from ..component_model.files import canonicalize_path
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, ExecutionContext
 from ..nodes.package import import_all_nodes_in_workspace
-from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
+from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
 
 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
@@ -544,7 +545,7 @@ class PromptExecutor:
         self.status_messages = []
         self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
 
-        with torch.inference_mode():
+        with torch.inference_mode() if all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) else nullcontext():
             dynamic_prompt = DynamicPrompt(prompt)
             is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
             for cache in self.caches.all:
@@ -558,7 +559,7 @@ class PromptExecutor:
 
             model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
-                             { "nodes": cached_nodes, "prompt_id": prompt_id},
+                             {"nodes": cached_nodes, "prompt_id": prompt_id},
                              broadcast=False)
             pending_subgraph_results = {}
             executed = set()
@@ -584,7 +585,7 @@ class PromptExecutor:
                     execution_list.complete_node_execution()
             else:
                 # Only execute when the while-loop ends without break
-                self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+                self.add_message("execution_success", {"prompt_id": prompt_id}, broadcast=False)
 
             ui_outputs = {}
             meta_outputs = {}
@@ -607,6 +608,11 @@ class PromptExecutor:
         return self.history_result["outputs"] if self.history_result is not None else None
 
 
+def iterate_obj_classes(prompt: dict[str, typing.Any]) -> typing.Generator[typing.Type[CustomNode], None, None]:
+    for _, node in prompt.items():
+        yield nodes.NODE_CLASS_MAPPINGS[node['class_type']]
+
+
 def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
     # todo: this should check if LoadImage / LoadImageMask paths exist
     # todo: or, nodes should provide a way to validate their values
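The executor change above gates torch.inference_mode() on the node classes appearing in the prompt: if every class either lacks INFERENCE_MODE or sets it truthy, inference mode is used; otherwise the executor falls back to nullcontext(). A minimal sketch of the equivalent predicate (wants_inference_mode and node_class_mappings are illustrative names, not part of the commit):

    import typing

    def wants_inference_mode(prompt: dict[str, typing.Any], node_class_mappings: dict) -> bool:
        # getattr with a True default is equivalent to
        # `not hasattr(cls, "INFERENCE_MODE") or cls.INFERENCE_MODE`
        classes = (node_class_mappings[node["class_type"]] for node in prompt.values())
        return all(getattr(cls, "INFERENCE_MODE", True) for cls in classes)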
@@ -248,6 +248,14 @@ class ModelPatcher(ModelManageable):
         else:
             return utils.get_attr(self.model, name)
 
+    @property
+    def diffusion_model(self) -> BaseModel:
+        return self.get_model_object("diffusion_model")
+
+    @diffusion_model.setter
+    def diffusion_model(self, value: torch.nn.Module):
+        self.add_object_patch("diffusion_model", value)
+
     def model_patches_to(self, device):
         to = self.model_options["transformer_options"]
         if "patches" in to:
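The new property pair routes diffusion-model swaps through the patch system: reads go through get_model_object, and assignment is recorded with add_object_patch instead of mutating the wrapped model, so it follows the patcher's clone/unpatch semantics. A hedged usage sketch, where `patcher` stands in for any ModelPatcher instance:

    import torch

    # Reading returns the currently patched module; assigning records an
    # object patch rather than touching patcher.model directly.
    compiled = torch.compile(patcher.diffusion_model)
    patcher.diffusion_model = compiled  # equivalent to patcher.add_object_patch("diffusion_model", compiled)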
@@ -21,6 +21,7 @@ from .. import sd
 from .. import utils
 from .. import clip_vision as clip_vision_module
 from .. import model_management
+from .. import ops
 from ..cli_args import args
 
 from ..cmd import folder_paths, latent_preview
@@ -898,9 +899,8 @@ class UNETLoader:
 
     CATEGORY = "advanced/loaders"
 
-    def load_unet(self, unet_name, weight_dtype):
+    def load_unet(self, unet_name, weight_dtype="default"):
         model_options = get_model_options_for_dtype(weight_dtype)
-
         unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
         model = sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)
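With the default now declared in the signature, the node can be called programmatically without naming a dtype, which is how the new tests below use it. A sketch (note the call downloads the checkpoint if it is not already present):

    from comfy.nodes.base_nodes import UNETLoader

    # weight_dtype may now be omitted; "default" resolves via get_model_options_for_dtype
    model, = UNETLoader().load_unet("flux1-dev.safetensors")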
@@ -127,6 +127,7 @@ class CustomNode(Protocol):
     FUNCTION: ClassVar[str]
     CATEGORY: ClassVar[str]
     OUTPUT_NODE: Optional[ClassVar[bool]]
+    INFERENCE_MODE: Optional[ClassVar[bool]]
 
     @classmethod
     def IS_CHANGED(cls, *args, **kwargs) -> str:
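INFERENCE_MODE joins the protocol as an optional class attribute. Nodes that compile or quantize modules set it to False, because tensors created under torch.inference_mode() cannot later participate in autograd or be mutated. A hypothetical opt-out node (the class and its fields are illustrative, not from the commit):

    class MutatingNode:
        FUNCTION = "run"
        CATEGORY = "_for_testing"
        RETURN_TYPES = ("MODEL",)
        # The executor wraps prompts containing this node in nullcontext()
        # instead of torch.inference_mode().
        INFERENCE_MODE = False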
comfy/ops.py (+39 lines)
@@ -60,12 +60,51 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         weight = s.weight_function(weight)
     return weight, bias
 
+class SkipInit:
+    def reset_parameters(self):
+        return None
+
 class CastWeightBiasOp:
     comfy_cast_weights = False
     weight_function = None
     bias_function = None
 
+class skip_init:
+    class Linear(SkipInit, torch.nn.Linear):
+        pass
+
+    class Conv1d(SkipInit, torch.nn.Conv1d):
+        pass
+
+    class Conv2d(SkipInit, torch.nn.Conv2d):
+        pass
+
+    class Conv3d(SkipInit, torch.nn.Conv3d):
+        pass
+
+    class GroupNorm(SkipInit, torch.nn.GroupNorm):
+        pass
+
+    class LayerNorm(SkipInit, torch.nn.LayerNorm):
+        pass
+
+    class ConvTranspose2d(SkipInit, torch.nn.ConvTranspose2d):
+        pass
+
+    class ConvTranspose1d(SkipInit, torch.nn.ConvTranspose1d):
+        pass
+
+    class Embedding(SkipInit, torch.nn.Embedding):
+        pass
+
+    @classmethod
+    def conv_nd(cls, dims, *args, **kwargs):
+        if dims == 2:
+            return cls.Conv2d(*args, **kwargs)
+        elif dims == 3:
+            return cls.Conv3d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
+
 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
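The SkipInit mixin works because torch.nn modules allocate parameters with torch.empty and then call self.reset_parameters() from __init__; placing SkipInit first in the MRO turns that call into a no-op, so construction skips the random-init pass, which is the point when a checkpoint will overwrite the weights anyway. A minimal standalone sketch (the zero tensors stand in for real checkpoint values):

    import torch

    class SkipInit:
        def reset_parameters(self):
            return None  # leave the torch.empty allocation untouched

    class Linear(SkipInit, torch.nn.Linear):
        pass

    layer = Linear(4096, 4096)  # no kaiming-uniform init pass runs
    # Weights are uninitialized memory until real values are loaded:
    layer.load_state_dict({"weight": torch.zeros(4096, 4096), "bias": torch.zeros(4096)})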
@@ -10,7 +10,7 @@ from comfy.nodes.package_typing import CustomNode, InputTypes
 DIFFUSION_MODEL = "diffusion_model"
 
 
-class TorchCompileModel:
+class TorchCompileModel(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -27,11 +27,12 @@ class TorchCompileModel:
 
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
+    INFERENCE_MODE = False
 
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
 
-    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
+    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor") -> tuple[ModelPatcher]:
         if object_patch is None:
             object_patch = DIFFUSION_MODEL
         compile_kwargs = {
@@ -39,6 +40,12 @@ class TorchCompileModel:
             "dynamic": dynamic,
             "backend": backend
         }
+        if backend == "torch_tensorrt":
+            compile_kwargs["options"] = {
+                # https://pytorch.org/TensorRT/dynamo/torch_compile.html
+                # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
+                "enabled_precisions": {torch.float, torch.half}
+            }
         if isinstance(model, ModelPatcher):
             m = model.clone()
             m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
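Outside the node, the wiring above amounts to an ordinary torch.compile call with the Torch-TensorRT backend and its options dict; INT8 is deliberately absent from enabled_precisions, which is what limits this path to FP16/FP32 until TensorRT quantization support lands. A sketch under the assumption that torch_tensorrt is installed and `module` is any nn.Module:

    import torch

    compiled = torch.compile(
        module,
        fullgraph=False,
        dynamic=False,
        backend="torch_tensorrt",
        # https://pytorch.org/TensorRT/dynamo/torch_compile.html
        options={"enabled_precisions": {torch.float, torch.half}},
    )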
@@ -50,57 +57,72 @@ class TorchCompileModel:
         return model,
 
 
+_QUANTIZATION_STRATEGIES = [
+    "torchao",
+    "quanto",
+    "torchao-autoquant"
+]
+
+
 class QuantizeModel(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
                 "model": ("MODEL", {}),
-                "strategy": (["torchao", "quanto"], {"default": "torchao"})
+                "strategy": (_QUANTIZATION_STRATEGIES, {"default": _QUANTIZATION_STRATEGIES[0]})
             }
         }
 
     FUNCTION = "execute"
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
+    INFERENCE_MODE = False
 
     RETURN_TYPES = ("MODEL",)
 
-    def execute(self, model: ModelPatcher, strategy: str = "torchao"):
+    def warn_in_place(self, model: ModelPatcher):
+        logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")
+        logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
+
+    def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
         model = model.clone()
         unet = model.get_model_object("diffusion_model")
         # todo: quantize quantizes in place, which is not desired
 
         # default exclusions
-        _unused_exclusions = {
+        always_exclude_these = {
             "time_embedding.",
             "add_embedding.",
-            "time_in.",
-            "txt_in.",
-            "vector_in.",
-            "img_in.",
-            "guidance_in.",
-            "final_layer.",
+            "time_in.in",
+            "txt_in",
+            "vector_in.in",
+            "img_in",
+            "guidance_in.in",
+            "final_layer",
         }
         if strategy == "quanto":
-            logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
+            self.warn_in_place(model)
             from optimum.quanto import quantize, qint8  # pylint: disable=import-error
             exclusion_list = [
                 name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None
             ]
             quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
             _in_place_fixme = unet
-        elif strategy == "torchao":
-            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight  # pylint: disable=import-error
-            model = model.clone()
-            unet = model.get_model_object("diffusion_model")
-            # todo: quantize quantizes in place, which is not desired
-
-            # def filter_fn(module: torch.nn.Module, name: str):
-            #     return any("weight" in name for name, _ in (module.named_parameters())) and all(exclusion not in name for exclusion in exclusions)
-            quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
-            _in_place_fixme = unet
+        elif "torchao" in strategy:
+            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, autoquant  # pylint: disable=import-error
+            self.warn_in_place(model)
+
+            def filter(module: torch.nn.Module, fqn: str) -> bool:
+                return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in always_exclude_these)
+
+            if "autoquant" in strategy:
+                _in_place_fixme = autoquant(unet, error_on_unseen=False)
+            else:
+                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter)
+                _in_place_fixme = unet
         else:
             raise ValueError(f"unknown strategy {strategy}")
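Both torchao paths mutate the module in place (hence warn_in_place and the _in_place_fixme naming): quantize_ swaps eligible Linear layers for int8 dynamic-activation/int8-weight variants, while autoquant wraps layers and selects kernels at first use. A self-contained sketch of the non-autoquant path, with a toy module standing in for the diffusion transformer:

    import torch
    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))

    def filter_fn(module: torch.nn.Module, fqn: str) -> bool:
        # Mirror the node: quantize Linear layers only, nothing on an exclusion list.
        return isinstance(module, torch.nn.Linear)

    quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)  # mutates model in place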
tests/quantization/__init__.py (new file, 0 lines)
tests/quantization/test_torchao.py (new file, +64 lines)
@@ -0,0 +1,64 @@
+import pytest
+import torch
+
+from comfy import model_management
+from comfy.model_base import Flux
+from comfy.model_patcher import ModelPatcher
+from comfy.nodes.base_nodes import UNETLoader
+from comfy_extras.nodes.nodes_torch_compile import QuantizeModel
+
+has_torchao = True
+try:
+    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
+except (ImportError, ModuleNotFoundError):
+    has_torchao = False
+
+has_tensorrt = True
+try:
+    from comfyui_tensorrt import STATIC_TRT_MODEL_CONVERSION
+except (ImportError, ModuleNotFoundError):
+    has_tensorrt = False
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
+async def test_unit_torchao(checkpoint_name):
+    # Downloads FLUX.1-dev and loads it using ComfyUI's models
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+
+    transformer: Flux = model.get_model_object("diffusion_model")
+    quantize_(transformer, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
+    assert transformer is not None
+    del transformer
+    model_management.unload_all_models()
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
+@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
+async def test_torchao_node(checkpoint_name, strategy):
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+
+    quantized_model, = QuantizeModel().execute(model, strategy=strategy)
+
+    transformer = quantized_model.get_model_object("diffusion_model")
+    del transformer
+    model_management.unload_all_models()
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
+@pytest.mark.skipif(True, reason="not yet supported")
+async def test_torchao_into_tensorrt(checkpoint_name, strategy):
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+    model_management.load_models_gpu([model], force_full_load=True)
+    model.diffusion_model = model.diffusion_model.to(memory_format=torch.channels_last)
+    model.diffusion_model = torch.compile(model.diffusion_model, mode="max-autotune", fullgraph=True)
+
+    quantized_model, = QuantizeModel().execute(model, strategy=strategy)
+
+    STATIC_TRT_MODEL_CONVERSION().convert(quantized_model, "test", 1, 1024, 1024, 1, 14)
+    model_management.unload_all_models()