Experimental quantization support. Only Linux is meaningfully supported

This commit is contained in:
doctorpangloss 2024-10-10 13:43:06 -07:00
parent 5f26b76f59
commit 69e523b89d
8 changed files with 165 additions and 25 deletions

View File

@ -9,6 +9,7 @@ import threading
import time
import traceback
import typing
from contextlib import nullcontext
from os import PathLike
from typing import List, Optional, Tuple
@ -28,7 +29,7 @@ from ..component_model.files import canonicalize_path
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import new_execution_context, ExecutionContext
from ..nodes.package import import_all_nodes_in_workspace
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode
# ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
# various functions that are declared here. It should have been a context in the first place.
@ -544,7 +545,7 @@ class PromptExecutor:
self.status_messages = []
self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode():
with torch.inference_mode() if all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) else nullcontext():
dynamic_prompt = DynamicPrompt(prompt)
is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
@ -558,7 +559,7 @@ class PromptExecutor:
model_management.cleanup_models(keep_clone_weights_loaded=True)
self.add_message("execution_cached",
{ "nodes": cached_nodes, "prompt_id": prompt_id},
{"nodes": cached_nodes, "prompt_id": prompt_id},
broadcast=False)
pending_subgraph_results = {}
executed = set()
@ -584,7 +585,7 @@ class PromptExecutor:
execution_list.complete_node_execution()
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
self.add_message("execution_success", {"prompt_id": prompt_id}, broadcast=False)
ui_outputs = {}
meta_outputs = {}
@ -607,6 +608,11 @@ class PromptExecutor:
return self.history_result["outputs"] if self.history_result is not None else None
def iterate_obj_classes(prompt: dict[str, typing.Any]) -> typing.Generator[typing.Type[CustomNode], None, None]:
for _, node in prompt.items():
yield nodes.NODE_CLASS_MAPPINGS[node['class_type']]
def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
# todo: this should check if LoadImage / LoadImageMask paths exist
# todo: or, nodes should provide a way to validate their values

View File

@ -248,6 +248,14 @@ class ModelPatcher(ModelManageable):
else:
return utils.get_attr(self.model, name)
@property
def diffusion_model(self) -> BaseModel:
return self.get_model_object("diffusion_model")
@diffusion_model.setter
def diffusion_model(self, value: torch.nn.Module):
self.add_object_patch("diffusion_model", value)
def model_patches_to(self, device):
to = self.model_options["transformer_options"]
if "patches" in to:

View File

@ -21,6 +21,7 @@ from .. import sd
from .. import utils
from .. import clip_vision as clip_vision_module
from .. import model_management
from .. import ops
from ..cli_args import args
from ..cmd import folder_paths, latent_preview
@ -898,9 +899,8 @@ class UNETLoader:
CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype):
def load_unet(self, unet_name, weight_dtype="default"):
model_options = get_model_options_for_dtype(weight_dtype)
unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
model = sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)

View File

@ -127,6 +127,7 @@ class CustomNode(Protocol):
FUNCTION: ClassVar[str]
CATEGORY: ClassVar[str]
OUTPUT_NODE: Optional[ClassVar[bool]]
INFERENCE_MODE: Optional[ClassVar[bool]]
@classmethod
def IS_CHANGED(cls, *args, **kwargs) -> str:

View File

@ -60,12 +60,51 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
weight = s.weight_function(weight)
return weight, bias
class SkipInit:
def reset_parameters(self):
return None
class CastWeightBiasOp:
comfy_cast_weights = False
weight_function = None
bias_function = None
class skip_init:
class Linear(SkipInit, torch.nn.Linear):
pass
class Conv1d(SkipInit, torch.nn.Conv1d):
pass
class Conv2d(SkipInit, torch.nn.Conv2d):
pass
class Conv3d(SkipInit, torch.nn.Conv3d):
pass
class GroupNorm(SkipInit, torch.nn.GroupNorm):
pass
class LayerNorm(SkipInit, torch.nn.LayerNorm):
pass
class ConvTranspose2d(SkipInit, torch.nn.ConvTranspose2d):
pass
class ConvTranspose1d(SkipInit, torch.nn.ConvTranspose1d):
pass
class Embedding(SkipInit, torch.nn.Embedding):
pass
@classmethod
def conv_nd(cls, dims, *args, **kwargs):
if dims == 2:
return cls.Conv2d(*args, **kwargs)
elif dims == 3:
return cls.Conv3d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):

View File

@ -10,7 +10,7 @@ from comfy.nodes.package_typing import CustomNode, InputTypes
DIFFUSION_MODEL = "diffusion_model"
class TorchCompileModel:
class TorchCompileModel(CustomNode):
@classmethod
def INPUT_TYPES(s):
return {
@ -27,11 +27,12 @@ class TorchCompileModel:
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
INFERENCE_MODE = False
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor") -> tuple[ModelPatcher]:
if object_patch is None:
object_patch = DIFFUSION_MODEL
compile_kwargs = {
@ -39,6 +40,12 @@ class TorchCompileModel:
"dynamic": dynamic,
"backend": backend
}
if backend == "torch_tensorrt":
compile_kwargs["options"] = {
# https://pytorch.org/TensorRT/dynamo/torch_compile.html
# Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
"enabled_precisions": {torch.float, torch.half}
}
if isinstance(model, ModelPatcher):
m = model.clone()
m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
@ -50,57 +57,72 @@ class TorchCompileModel:
return model,
_QUANTIZATION_STRATEGIES = [
"torchao",
"quanto",
"torchao-autoquant"
]
class QuantizeModel(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL", {}),
"strategy": (["torchao", "quanto"], {"default": "torchao"})
"strategy": (_QUANTIZATION_STRATEGIES, {"default": _QUANTIZATION_STRATEGIES[0]})
}
}
FUNCTION = "execute"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
INFERENCE_MODE = False
RETURN_TYPES = ("MODEL",)
def execute(self, model: ModelPatcher, strategy: str = "torchao"):
def warn_in_place(self, model: ModelPatcher):
logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")
logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
model = model.clone()
unet = model.get_model_object("diffusion_model")
# todo: quantize quantizes in place, which is not desired
# default exclusions
_unused_exclusions = {
always_exclude_these = {
"time_embedding.",
"add_embedding.",
"time_in.",
"txt_in.",
"vector_in.",
"img_in.",
"guidance_in.",
"final_layer.",
"time_in.in",
"txt_in",
"vector_in.in",
"img_in",
"guidance_in.in",
"final_layer",
}
if strategy == "quanto":
logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
self.warn_in_place(model)
from optimum.quanto import quantize, qint8 # pylint: disable=import-error
exclusion_list = [
name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None
]
quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
_in_place_fixme = unet
elif strategy == "torchao":
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight # pylint: disable=import-error
elif "torchao" in strategy:
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, autoquant # pylint: disable=import-error
model = model.clone()
self.warn_in_place(model)
unet = model.get_model_object("diffusion_model")
# todo: quantize quantizes in place, which is not desired
# def filter_fn(module: torch.nn.Module, name: str):
# return any("weight" in name for name, _ in (module.named_parameters())) and all(exclusion not in name for exclusion in exclusions)
quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
_in_place_fixme = unet
def filter(module: torch.nn.Module, fqn: str) -> bool:
return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in always_exclude_these)
if "autoquant" in strategy:
_in_place_fixme = autoquant(unet, error_on_unseen=False)
else:
quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter)
_in_place_fixme = unet
else:
raise ValueError(f"unknown strategy {strategy}")

View File

View File

@ -0,0 +1,64 @@
import pytest
import torch
from comfy import model_management
from comfy.model_base import Flux
from comfy.model_patcher import ModelPatcher
from comfy.nodes.base_nodes import UNETLoader
from comfy_extras.nodes.nodes_torch_compile import QuantizeModel
has_torchao = True
try:
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
except (ImportError, ModuleNotFoundError):
has_torchao = False
has_tensorrt = True
try:
from comfyui_tensorrt import STATIC_TRT_MODEL_CONVERSION
except (ImportError, ModuleNotFoundError):
has_tensorrt = False
@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
async def test_unit_torchao(checkpoint_name):
# Downloads FLUX.1-dev and loads it using ComfyUI's models
model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
model: ModelPatcher = model.clone()
transformer: Flux = model.get_model_object("diffusion_model")
quantize_(transformer, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
assert transformer is not None
del transformer
model_management.unload_all_models()
@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
async def test_torchao_node(checkpoint_name, strategy):
model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
model: ModelPatcher = model.clone()
quantized_model, = QuantizeModel().execute(model, strategy=strategy)
transformer = quantized_model.get_model_object("diffusion_model")
del transformer
model_management.unload_all_models()
@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
@pytest.mark.skipif(True, reason="not yet supported")
async def test_torchao_into_tensorrt(checkpoint_name, strategy):
model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
model: ModelPatcher = model.clone()
model_management.load_models_gpu([model], force_full_load=True)
model.diffusion_model = model.diffusion_model.to(memory_format=torch.channels_last)
model.diffusion_model = torch.compile(model.diffusion_model, mode="max-autotune", fullgraph=True)
quantized_model, = QuantizeModel().execute(model, strategy=strategy)
STATIC_TRT_MODEL_CONVERSION().convert(quantized_model, "test", 1, 1024, 1024, 1, 14)
model_management.unload_all_models()