Experimental quantization support. Only Linux is meaningfully supported

doctorpangloss 2024-10-10 13:43:06 -07:00
parent 5f26b76f59
commit 69e523b89d
8 changed files with 165 additions and 25 deletions

View File

@@ -9,6 +9,7 @@ import threading
 import time
 import traceback
 import typing
+from contextlib import nullcontext
 from os import PathLike
 from typing import List, Optional, Tuple
@@ -28,7 +29,7 @@ from ..component_model.files import canonicalize_path
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, ExecutionContext
 from ..nodes.package import import_all_nodes_in_workspace
-from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions
+from ..nodes.package_typing import ExportedNodes, InputTypeSpec, FloatSpecOptions, IntSpecOptions, CustomNode

 # ideally this would be passed in from main, but the way this is authored, we can't easily pass nodes down to the
 # various functions that are declared here. It should have been a context in the first place.
@@ -544,7 +545,7 @@ class PromptExecutor:
         self.status_messages = []
         self.add_message("execution_start", {"prompt_id": prompt_id}, broadcast=False)

-        with torch.inference_mode():
+        with torch.inference_mode() if all(not hasattr(node_class, "INFERENCE_MODE") or node_class.INFERENCE_MODE for node_class in iterate_obj_classes(prompt)) else nullcontext():
             dynamic_prompt = DynamicPrompt(prompt)
             is_changed_cache = IsChangedCache(dynamic_prompt, self.caches.outputs)
             for cache in self.caches.all:
@@ -558,7 +559,7 @@ class PromptExecutor:
             model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
-                             { "nodes": cached_nodes, "prompt_id": prompt_id},
+                             {"nodes": cached_nodes, "prompt_id": prompt_id},
                              broadcast=False)

             pending_subgraph_results = {}
             executed = set()
@@ -584,7 +585,7 @@ class PromptExecutor:
                     execution_list.complete_node_execution()
                 else:
                     # Only execute when the while-loop ends without break
-                    self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
+                    self.add_message("execution_success", {"prompt_id": prompt_id}, broadcast=False)

             ui_outputs = {}
             meta_outputs = {}
@@ -607,6 +608,11 @@
         return self.history_result["outputs"] if self.history_result is not None else None


+def iterate_obj_classes(prompt: dict[str, typing.Any]) -> typing.Generator[typing.Type[CustomNode], None, None]:
+    for _, node in prompt.items():
+        yield nodes.NODE_CLASS_MAPPINGS[node['class_type']]
+
+
 def validate_inputs(prompt, item, validated: typing.Dict[str, ValidateInputsTuple]) -> ValidateInputsTuple:
     # todo: this should check if LoadImage / LoadImageMask paths exist
     # todo: or, nodes should provide a way to validate their values
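Note for reviewers: the predicate in the new with statement reads more easily when pulled out into a helper. A minimal sketch of the equivalent check, reusing the iterate_obj_classes generator added in this hunk; the helper name prompt_wants_inference_mode is illustrative only and not part of this commit:

    def prompt_wants_inference_mode(prompt: dict) -> bool:
        # A prompt runs under torch.inference_mode() unless any node class in it
        # sets INFERENCE_MODE = False; nodes that compile or quantize modules need
        # ordinary tensors, since inference-mode tensors cannot be mutated later.
        return all(getattr(node_class, "INFERENCE_MODE", True) for node_class in iterate_obj_classes(prompt))

    ctx = torch.inference_mode() if prompt_wants_inference_mode(prompt) else nullcontext()
    with ctx:
        ...  # dynamic prompt construction and node execution, as above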

View File

@@ -248,6 +248,14 @@ class ModelPatcher(ModelManageable):
         else:
             return utils.get_attr(self.model, name)

+    @property
+    def diffusion_model(self) -> BaseModel:
+        return self.get_model_object("diffusion_model")
+
+    @diffusion_model.setter
+    def diffusion_model(self, value: torch.nn.Module):
+        self.add_object_patch("diffusion_model", value)
+
     def model_patches_to(self, device):
         to = self.model_options["transformer_options"]
         if "patches" in to:

View File

@@ -21,6 +21,7 @@ from .. import sd
 from .. import utils
 from .. import clip_vision as clip_vision_module
 from .. import model_management
+from .. import ops
 from ..cli_args import args
 from ..cmd import folder_paths, latent_preview
@@ -898,9 +899,8 @@ class UNETLoader:
     CATEGORY = "advanced/loaders"

-    def load_unet(self, unet_name, weight_dtype):
+    def load_unet(self, unet_name, weight_dtype="default"):
         model_options = get_model_options_for_dtype(weight_dtype)
         unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
         model = sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)
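With the new default, callers such as the tests added at the end of this commit can omit weight_dtype. A hedged example, assuming the model file resolves through get_or_download:

    # weight_dtype now defaults to "default"; both calls are equivalent.
    model, = UNETLoader().load_unet("flux1-dev.safetensors")
    model, = UNETLoader().load_unet("flux1-dev.safetensors", weight_dtype="default")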

View File

@@ -127,6 +127,7 @@ class CustomNode(Protocol):
     FUNCTION: ClassVar[str]
     CATEGORY: ClassVar[str]
     OUTPUT_NODE: Optional[ClassVar[bool]]
+    INFERENCE_MODE: Optional[ClassVar[bool]]

     @classmethod
     def IS_CHANGED(cls, *args, **kwargs) -> str:
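A minimal sketch of a node opting out of inference mode through the new class variable; the node below is hypothetical and only illustrates the protocol member:

    class QuantizeInPlace(CustomNode):
        # Any prompt containing this node is executed under contextlib.nullcontext()
        # instead of torch.inference_mode(), so module parameters can be replaced.
        INFERENCE_MODE = False
        RETURN_TYPES = ("MODEL",)
        FUNCTION = "execute"

        @classmethod
        def INPUT_TYPES(cls):
            return {"required": {"model": ("MODEL", {})}}

        def execute(self, model):
            ...  # e.g. quantize or torch.compile the wrapped module
            return (model,)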

View File

@@ -60,12 +60,51 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
         weight = s.weight_function(weight)
     return weight, bias


+class SkipInit:
+    def reset_parameters(self):
+        return None
+
+
 class CastWeightBiasOp:
     comfy_cast_weights = False
     weight_function = None
     bias_function = None


+class skip_init:
+    class Linear(SkipInit, torch.nn.Linear):
+        pass
+
+    class Conv1d(SkipInit, torch.nn.Conv1d):
+        pass
+
+    class Conv2d(SkipInit, torch.nn.Conv2d):
+        pass
+
+    class Conv3d(SkipInit, torch.nn.Conv3d):
+        pass
+
+    class GroupNorm(SkipInit, torch.nn.GroupNorm):
+        pass
+
+    class LayerNorm(SkipInit, torch.nn.LayerNorm):
+        pass
+
+    class ConvTranspose2d(SkipInit, torch.nn.ConvTranspose2d):
+        pass
+
+    class ConvTranspose1d(SkipInit, torch.nn.ConvTranspose1d):
+        pass
+
+    class Embedding(SkipInit, torch.nn.Embedding):
+        pass
+
+    @classmethod
+    def conv_nd(cls, dims, *args, **kwargs):
+        if dims == 2:
+            return cls.Conv2d(*args, **kwargs)
+        elif dims == 3:
+            return cls.Conv3d(*args, **kwargs)
+        else:
+            raise ValueError(f"unsupported dimensions: {dims}")
+
+
 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
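The SkipInit mixin overrides reset_parameters with a no-op, so constructing these layers allocates parameters but skips torch's random initialization, which is wasted work when a checkpoint state_dict is loaded immediately afterwards. A hedged usage sketch, assuming the classes above are in scope; the layer sizes and the checkpoint_linear_state name are placeholders:

    # Parameters come back uninitialized (torch.empty storage); load real weights before use.
    linear = skip_init.Linear(4096, 4096, bias=True)
    conv = skip_init.conv_nd(2, 320, 320, kernel_size=3)  # dispatches to skip_init.Conv2d
    linear.load_state_dict(checkpoint_linear_state)       # placeholder state_dict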

View File

@@ -10,7 +10,7 @@ from comfy.nodes.package_typing import CustomNode, InputTypes
 DIFFUSION_MODEL = "diffusion_model"


-class TorchCompileModel:
+class TorchCompileModel(CustomNode):
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -27,11 +27,12 @@ class TorchCompileModel:
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
+    INFERENCE_MODE = False

     CATEGORY = "_for_testing"
     EXPERIMENTAL = True

-    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
+    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor") -> tuple[ModelPatcher]:
         if object_patch is None:
             object_patch = DIFFUSION_MODEL
         compile_kwargs = {
@ -39,6 +40,12 @@ class TorchCompileModel:
"dynamic": dynamic, "dynamic": dynamic,
"backend": backend "backend": backend
} }
if backend == "torch_tensorrt":
compile_kwargs["options"] = {
# https://pytorch.org/TensorRT/dynamo/torch_compile.html
# Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
"enabled_precisions": {torch.float, torch.half}
}
if isinstance(model, ModelPatcher): if isinstance(model, ModelPatcher):
m = model.clone() m = model.clone()
m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs)) m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
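A hedged sketch of driving the node directly, assuming model is a ModelPatcher returned by a loader node; with backend="torch_tensorrt" the options above pin precision to FP32/FP16 because the TensorRT dynamo backend does not accept INT8 here yet:

    compiled, = TorchCompileModel().patch(model, backend="inductor", fullgraph=False, dynamic=False)
    # Requires the torch_tensorrt package; compiles through the TensorRT dynamo backend.
    trt_compiled, = TorchCompileModel().patch(model, backend="torch_tensorrt")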
@@ -50,57 +57,72 @@
             return model,


+_QUANTIZATION_STRATEGIES = [
+    "torchao",
+    "quanto",
+    "torchao-autoquant",
+]
+
+
 class QuantizeModel(CustomNode):
     @classmethod
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
                 "model": ("MODEL", {}),
-                "strategy": (["torchao", "quanto"], {"default": "torchao"})
+                "strategy": (_QUANTIZATION_STRATEGIES, {"default": _QUANTIZATION_STRATEGIES[0]})
             }
         }

     FUNCTION = "execute"
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
+    INFERENCE_MODE = False
     RETURN_TYPES = ("MODEL",)

-    def execute(self, model: ModelPatcher, strategy: str = "torchao"):
+    def warn_in_place(self, model: ModelPatcher):
+        logging.warning(f"Quantizing {model} this way quantizes it in place, making it unsuitable for cloning. All uses of this model will be quantized.")
+        logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
+
+    def execute(self, model: ModelPatcher, strategy: str = _QUANTIZATION_STRATEGIES[0]) -> tuple[ModelPatcher]:
         model = model.clone()
         unet = model.get_model_object("diffusion_model")
         # todo: quantize quantizes in place, which is not desired
         # default exclusions
-        _unused_exclusions = {
+        always_exclude_these = {
             "time_embedding.",
             "add_embedding.",
-            "time_in.",
-            "txt_in.",
-            "vector_in.",
-            "img_in.",
-            "guidance_in.",
-            "final_layer.",
+            "time_in.in",
+            "txt_in",
+            "vector_in.in",
+            "img_in",
+            "guidance_in.in",
+            "final_layer",
         }
         if strategy == "quanto":
-            logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
+            self.warn_in_place(model)
             from optimum.quanto import quantize, qint8  # pylint: disable=import-error
             exclusion_list = [
                 name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None
             ]
             quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
             _in_place_fixme = unet
-        elif strategy == "torchao":
-            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight  # pylint: disable=import-error
+        elif "torchao" in strategy:
+            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, autoquant  # pylint: disable=import-error
             model = model.clone()
+            self.warn_in_place(model)
             unet = model.get_model_object("diffusion_model")
-            # todo: quantize quantizes in place, which is not desired
-            # def filter_fn(module: torch.nn.Module, name: str):
-            #     return any("weight" in name for name, _ in (module.named_parameters())) and all(exclusion not in name for exclusion in exclusions)
-            quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
-            _in_place_fixme = unet
+
+            def filter(module: torch.nn.Module, fqn: str) -> bool:
+                return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in always_exclude_these)
+
+            if "autoquant" in strategy:
+                _in_place_fixme = autoquant(unet, error_on_unseen=False)
+            else:
+                quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device(), filter_fn=filter)
+                _in_place_fixme = unet
         else:
             raise ValueError(f"unknown strategy {strategy}")

View File

View File

@@ -0,0 +1,64 @@
+import pytest
+import torch
+
+from comfy import model_management
+from comfy.model_base import Flux
+from comfy.model_patcher import ModelPatcher
+from comfy.nodes.base_nodes import UNETLoader
+from comfy_extras.nodes.nodes_torch_compile import QuantizeModel
+
+has_torchao = True
+try:
+    from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
+except (ImportError, ModuleNotFoundError):
+    has_torchao = False
+
+has_tensorrt = True
+try:
+    from comfyui_tensorrt import STATIC_TRT_MODEL_CONVERSION
+except (ImportError, ModuleNotFoundError):
+    has_tensorrt = False
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
+async def test_unit_torchao(checkpoint_name):
+    # Downloads FLUX.1-dev and loads it using ComfyUI's models
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+    transformer: Flux = model.get_model_object("diffusion_model")
+    quantize_(transformer, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
+    assert transformer is not None
+    del transformer
+    model_management.unload_all_models()
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
+@pytest.mark.skipif(not has_torchao, reason="torchao not installed")
+async def test_torchao_node(checkpoint_name, strategy):
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+    quantized_model, = QuantizeModel().execute(model, strategy=strategy)
+    transformer = quantized_model.get_model_object("diffusion_model")
+    del transformer
+    model_management.unload_all_models()
+
+
+@pytest.mark.parametrize("checkpoint_name", ["flux1-dev.safetensors"])
+@pytest.mark.parametrize("strategy", ["torchao", "torchao-autoquant"])
+@pytest.mark.skipif(True, reason="not yet supported")
+async def test_torchao_into_tensorrt(checkpoint_name, strategy):
+    model, = UNETLoader().load_unet(checkpoint_name, weight_dtype="default")
+    model: ModelPatcher = model.clone()
+    model_management.load_models_gpu([model], force_full_load=True)
+    model.diffusion_model = model.diffusion_model.to(memory_format=torch.channels_last)
+    model.diffusion_model = torch.compile(model.diffusion_model, mode="max-autotune", fullgraph=True)
+    quantized_model, = QuantizeModel().execute(model, strategy=strategy)
+    STATIC_TRT_MODEL_CONVERSION().convert(quantized_model, "test", 1, 1024, 1024, 1, 14)
+    model_management.unload_all_models()
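These tests download FLUX.1-dev on first run and need a CUDA device plus torchao; they are written as async functions, so an asyncio-enabled pytest configuration is assumed. A hedged invocation (the test file's path is not shown in this diff, so select by test name instead):

    pytest -q -k "torchao and not tensorrt"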