Experimental quantization support. Only Linux is meaningfully supported.

parent 5f26b76f59
commit 69e523b89d
@@ -9,6 +9,7 @@ import threading
 import time
 import traceback
 import typing
+from contextlib import nullcontext
 from os import PathLike
 from typing import List, Optional, Tuple

@@ -28,7 +29,7 @@ from ..component_model.files import canonicalize_path
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, ExecutionContext
 from ..nodes.package import import_all_nodes_in_workspace
-from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
+from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode

 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
@@ -544,7 +545,7 @@ class PromptExecutor:
         self.status_messages = []
         self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)

-        with torch.inference_mode():
+        with torch.inference_mode() if all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) else nullcontext():
             dynamic_prompt = DynamicPrompt(prompt)
             is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
             for cache in self.caches.all:
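Note on the gated context manager above: the prompt now runs under torch.inference_mode() only when every node class in the prompt either lacks an INFERENCE_MODE attribute or sets it truthy; a single node declaring INFERENCE_MODE = False drops the whole prompt back to a no-op context. A minimal sketch of the same decision, with hypothetical node classes purely for illustration:

    import torch
    from contextlib import nullcontext

    class CompileNode:        # hypothetical node that must not run under inference_mode
        INFERENCE_MODE = False

    class OrdinaryNode:       # hypothetical node with no opinion (attribute absent)
        pass

    node_classes = [CompileNode, OrdinaryNode]

    # mirrors the expression in PromptExecutor.execute: all() is True only if no
    # node class explicitly opts out; otherwise fall back to nullcontext()
    ctx = torch.inference_mode() if all(
        not hasattr(cls, "INFERENCE_MODE") or cls.INFERENCE_MODE
        for cls in node_classes
    ) else nullcontext()

    with ctx:
        pass  # graph execution happens here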
@@ -558,7 +559,7 @@ class PromptExecutor:

             model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
-                             { "nodes": cached_nodes, "prompt_id": prompt_id},
+                             {"nodes": cached_nodes, "prompt_id": prompt_id},
                              broadcast=False)
             pending_subgraph_results = {}
             executed = set()
@@ -584,7 +585,7 @@ class PromptExecutor:
                     execution_list.complete_node_execution()
             else:
                 # Only execute when the while-loop ends without break
-                self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+                self.add_message("execution_success", {"prompt_id": prompt_id}, broadcast=False)

             ui_outputs = {}
             meta_outputs = {}
@@ -607,6 +608,11 @@ class PromptExecutor:
         return self.history_result["outputs"] if self.history_result is not None else None


+def iterate_obj_classes(prompt: dict[str, typing.Any]) -> typing.Generator[typing.Type[CustomNode], None, None]:
+    for _, node in prompt.items():
+        yield nodes.NODE_CLASS_MAPPINGS[node['class_type']]
+
+
 def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
     # todo: this should check if LoadImage / LoadImageMask paths exist
     # todo: or, nodes should provide a way to validate their values
@@ -248,6 +248,14 @@ class ModelPatcher(ModelManageable):
         else:
             return utils.get_attr(self.model, name)

+    @property
+    def diffusion_model(self) -> BaseModel:
+        return self.get_model_object("diffusion_model")
+
+    @diffusion_model.setter
+    def diffusion_model(self, value: torch.nn.Module):
+        self.add_object_patch("diffusion_model", value)
+
     def model_patches_to(self, device):
         to = self.model_options["transformer_options"]
         if "patches" in to:
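The new property pair makes patcher.diffusion_model a shorthand for get_model_object / add_object_patch, so assigning to it registers an object patch rather than mutating the wrapped model. A minimal usage sketch; the checkpoint name follows the tests added in this commit and would be downloaded on first use:

    import torch
    from comfy.nodes.base_nodes import UNETLoader

    # load a diffusion model as a ModelPatcher (file name is just an example)
    patcher, = UNETLoader().load_unet("flux1-dev.safetensors", weight_dtype="default")

    unet = patcher.diffusion_model                 # same as patcher.get_model_object("diffusion_model")
    patcher.diffusion_model = torch.compile(unet)  # same as patcher.add_object_patch("diffusion_model", ...)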
@@ -21,6 +21,7 @@ from .. import sd
 from .. import utils
 from .. import clip_vision as clip_vision_module
 from .. import model_management
+from .. import ops
 from ..cli_args import args

 from ..cmd import folder_paths, latent_preview
@@ -898,9 +899,8 @@ class UNETLoader:

     CATEGORY = "advanced/loaders"

-    def load_unet(self, unet_name, weight_dtype):
+    def load_unet(self, unet_name, weight_dtype="default"):
         model_options = get_model_options_for_dtype(weight_dtype)

         unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
         model = sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)
@@ -127,6 +127,7 @@ class CustomNode(Protocol):
     FUNCTION: ClassVar[str]
     CATEGORY: ClassVar[str]
     OUTPUT_NODE: Optional[ClassVar[bool]]
+    INFERENCE_MODE: Optional[ClassVar[bool]]

     @classmethod
     def IS_CHANGED(cls, *args, **kwargs) -> str:
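With the protocol extended, a custom node opts out of inference-mode execution by declaring the new class variable, as TorchCompileModel and QuantizeModel do later in this diff. A hypothetical minimal node for illustration (the class name and behavior are not part of this commit):

    class MyPatchingNode:
        RETURN_TYPES = ("MODEL",)
        FUNCTION = "patch"
        CATEGORY = "_for_testing"
        # opt out: objects created while patching must stay usable outside inference mode
        INFERENCE_MODE = False

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"model": ("MODEL", {})}}

        def patch(self, model):
            return (model,)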
comfy/ops.py (39 lines changed)
@@ -60,12 +60,51 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         weight = s.weight_function(weight)
     return weight, bias

+class SkipInit:
+    def reset_parameters(self):
+        return None
+
 class CastWeightBiasOp:
     comfy_cast_weights = False
     weight_function = None
     bias_function = None

+class skip_init:
+    class Linear(SkipInit, torch.nn.Linear):
+        pass
+
+    class Conv1d(SkipInit, torch.nn.Conv1d):
+        pass
+
+    class Conv2d(SkipInit, torch.nn.Conv2d):
+        pass
+
+    class Conv3d(SkipInit, torch.nn.Conv3d):
+        pass
+
+    class GroupNorm(SkipInit, torch.nn.GroupNorm):
+        pass
+
+    class LayerNorm(SkipInit, torch.nn.LayerNorm):
+        pass
+
+    class ConvTranspose2d(SkipInit, torch.nn.ConvTranspose2d):
+        pass
+
+    class ConvTranspose1d(SkipInit, torch.nn.ConvTranspose1d):
+        pass
+
+    class Embedding(SkipInit, torch.nn.Embedding):
+        pass
+
+    @classmethod
+    def conv_nd(cls, dims, *args, **kwargs):
+        if dims == 2:
+            return cls.Conv2d(*args, **kwargs)
+        elif dims == 3:
+            return cls.Conv3d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
+
 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
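The SkipInit mixin overrides reset_parameters with a no-op, so the skip_init layer classes still allocate their parameter tensors (via torch.empty in the base constructors) but never spend time on random initialization. That only makes sense when a checkpoint's state_dict overwrites the weights immediately afterwards. A small sketch of that pattern, assuming the state_dict source (zero tensors stand in for real checkpoint weights):

    import torch
    from comfy import ops

    # parameters are allocated but reset_parameters() does nothing,
    # so the values are uninitialized memory at this point
    layer = ops.skip_init.Linear(4096, 4096)

    # the uninitialized weights are expected to be replaced right away,
    # e.g. from a checkpoint's state_dict
    layer.load_state_dict({"weight": torch.zeros(4096, 4096), "bias": torch.zeros(4096)})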
@@ -10,7 +10,7 @@ from comfy.nodes.package_typing import CustomNode, InputTypes
 DIFFUSION_MODEL = "diffusion_model"


-class TorchCompileModel:
+class TorchCompileModel(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -27,11 +27,12 @@ class TorchCompileModel:

     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
+    INFERENCE_MODE = False

     CATEGORY = "_for_testing"
     EXPERIMENTAL = True

-    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
+    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor") -> tuple[ModelPatcher]:
         if object_patch is None:
             object_patch = DIFFUSION_MODEL
         compile_kwargs = {
@@ -39,6 +40,12 @@ class TorchCompileModel:
             "dynamic": dynamic,
             "backend": backend
         }
+        if backend == "torch_tensorrt":
+            compile_kwargs["options"] = {
+                # https://pytorch.org/TensorRT/dynamo/torch_compile.html
+                # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
+                "enabled_precisions": {torch.float, torch.half}
+            }
         if isinstance(model, ModelPatcher):
             m = model.clone()
             m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
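For the torch_tensorrt backend the node now pins enabled_precisions to FP32/FP16, since (per the linked docs) INT8 is not yet supported on that torch.compile path. The end result is an ordinary torch.compile call with backend-specific options; a condensed sketch of what the node assembles, assuming the torch-tensorrt package is installed when that backend is chosen:

    import torch

    def compile_object(module: torch.nn.Module, backend: str = "inductor") -> torch.nn.Module:
        # mirrors how TorchCompileModel builds its kwargs
        compile_kwargs = {"fullgraph": False, "dynamic": False, "backend": backend}
        if backend == "torch_tensorrt":
            # INT8/quantized layers are not supported by this backend yet
            compile_kwargs["options"] = {"enabled_precisions": {torch.float, torch.half}}
        return torch.compile(model=module, **compile_kwargs)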
@@ -50,57 +57,72 @@ class TorchCompileModel:
         return model,


+_QUANTIZATION_STRATEGIES = [
+    "torchao",
+    "quanto",
+    "torchao-autoquant"
+]
+
+
 class QuantizeModel(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
                 "model": ("MODEL", {}),
-                "strategy": (["torchao", "quanto"], {"default": "torchao"})
+                "strategy": (_QUANTIZATION_STRATEGIES, {"default": _QUANTIZATION_STRATEGIES[0]})
             }
         }

     FUNCTION = "execute"
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
+    INFERENCE_MODE = False

     RETURN_TYPES = ("MODEL",)

-    def execute(self, model: ModelPatcher, strategy: str = "torchao"):
+    def warn_in_place(self, model: ModelPatcher):
         logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")
-        logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
+
+    def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
         model = model.clone()
         unet = model.get_model_object("diffusion_model")
         # todo: quantize quantizes in place, which is not desired

         # default exclusions
-        _unused_exclusions = {
+        always_exclude_these = {
             "time_embedding.",
             "add_embedding.",
-            "time_in.",
-            "txt_in.",
-            "vector_in.",
-            "img_in.",
-            "guidance_in.",
-            "final_layer.",
+            "time_in.in",
+            "txt_in",
+            "vector_in.in",
+            "img_in",
+            "guidance_in.in",
+            "final_layer",
         }
         if strategy == "quanto":
+            logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
+            self.warn_in_place(model)
             from optimum.quanto import quantize, qint8  # pylint: disable=import-error
             exclusion_list = [
                 name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None
             ]
             quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
             _in_place_fixme = unet
-        elif strategy == "torchao":
-            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight  # pylint: disable=import-error
+        elif "torchao" in strategy:
+            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, autoquant  # pylint: disable=import-error
             model = model.clone()
+            self.warn_in_place(model)
             unet = model.get_model_object("diffusion_model")
-            # todo: quantize quantizes in place, which is not desired

-            # def filter_fn(module: torch.nn.Module, name: str):
-            #     return any("weight" in name for name, _ in (module.named_parameters())) and all(exclusion not in name for exclusion in exclusions)
-            quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
-            _in_place_fixme = unet
+            def filter(module: torch.nn.Module, fqn: str) -> bool:
+                return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in always_exclude_these)
+
+            if "autoquant" in strategy:
+                _in_place_fixme = autoquant(unet, error_on_unseen=False)
+            else:
+                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter)
+                _in_place_fixme = unet
         else:
             raise ValueError(f"unknown strategy {strategy}")
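The torchao branch above quantizes Linear layers in place with int8 dynamic-activation / int8-weight quantization, using the filter_fn to skip the embedding and IO modules listed in always_exclude_these, while the -autoquant variant defers kernel selection to torchao's autoquant. A standalone sketch of the non-autoquant path against a toy module tree, assuming torchao is installed (module names are illustrative, not FLUX's):

    import torch
    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

    exclusions = {"time_embedding.", "final_layer"}

    def filter_fn(module: torch.nn.Module, fqn: str) -> bool:
        # quantize only Linear layers whose fully qualified name is not excluded
        return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in exclusions)

    model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64))
    # in place: the passed-in model's Linear weights are replaced with quantized tensors
    quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)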
tests/quantization/__init__.py (new file, 0 lines)
tests/quantization/test_torchao.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+import pytest
+import torch
+
+from comfy import model_management
+from comfy.model_base import Flux
+from comfy.model_patcher import ModelPatcher
+from comfy.nodes.base_nodes import UNETLoader
+from comfy_extras.nodes.nodes_torch_compile import QuantizeModel
+
+has_torchao = True
+try:
+    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
+except (ImportError, ModuleNotFoundError):
+    has_torchao = False
+
+has_tensorrt = True
+try:
+    from comfyui_tensorrt import STATIC_TRT_MODEL_CONVERSION
+except (ImportError, ModuleNotFoundError):
+    has_tensorrt = False
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
+async def test_unit_torchao(checkpoint_name):
+    # Downloads FLUX.1-dev and loads it using ComfyUI's models
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+
+    transformer: Flux = model.get_model_object("diffusion_model")
+    quantize_(transformer, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
+    assert transformer is not None
+    del transformer
+    model_management.unload_all_models()
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
+@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
+async def test_torchao_node(checkpoint_name, strategy):
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+
+    quantized_model, = QuantizeModel().execute(model, strategy=strategy)
+
+    transformer = quantized_model.get_model_object("diffusion_model")
+    del transformer
+    model_management.unload_all_models()
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
+@pytest.mark.skipif(True, reason="not yet supported")
+async def test_torchao_into_tensorrt(checkpoint_name, strategy):
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+    model_management.load_models_gpu([model], force_full_load=True)
+    model.diffusion_model = model.diffusion_model.to(memory_format=torch.channels_last)
+    model.diffusion_model = torch.compile(model.diffusion_model, mode="max-autotune", fullgraph=True)
+
+    quantized_model, = QuantizeModel().execute(model, strategy=strategy)
+
+    STATIC_TRT_MODEL_CONVERSION().convert(quantized_model, "test", 1, 1024, 1024, 1, 14)
+    model_management.unload_all_models()