mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
- Experimental support for sage attention on Linux - Diffusers loader now supports model indices - Transformers model management now aligns with updates to ComfyUI - Flux layers correctly use unbind - Add float8 support for model loading in more places - Experimental quantization approaches from Quanto and torchao - Model upscaling interacts with memory management better This update also disables ROCm testing because it isn't reliable enough on consumer hardware. ROCm is not really supported by the 7600.
115 lines
4.0 KiB
Python
115 lines
4.0 KiB
Python
import logging
|
|
|
|
import torch
|
|
from torch.nn import LayerNorm
|
|
|
|
from comfy import model_management
|
|
from comfy.model_patcher import ModelPatcher
|
|
from comfy.nodes.package_typing import CustomNode, InputTypes
|
|
|
|
# Default object-patch key used to reach the wrapped diffusion model inside a ModelPatcher.
DIFFUSION_MODEL = "diffusion_model"
|
|
|
|
|
|
class TorchCompileModel:
    """Experimental node that wraps part of a model with ``torch.compile``.

    Exposes torch.compile's ``fullgraph``, ``dynamic`` and ``backend`` knobs
    and applies the compiled module as an object patch on a cloned
    ModelPatcher, so the original model object is left untouched.
    """

    @classmethod
    def INPUT_TYPES(cls):
        # `cls` (not `s`) for consistency with QuantizeModel below.
        return {
            "required": {
                "model": ("MODEL",),
            },
            "optional": {
                "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
                "fullgraph": ("BOOLEAN", {"default": False}),
                "dynamic": ("BOOLEAN", {"default": False}),
                "backend": ("STRING", {"default": "inductor"}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor"):
        """Compile ``object_patch`` of ``model`` and return the patched model.

        :param model: a ModelPatcher (cloned, then patched with the compiled
            sub-module) or a bare ``torch.nn.Module`` (compiled directly);
            anything else is returned unchanged with a warning.
        :param object_patch: dotted path of the sub-module to compile;
            ``None`` falls back to the diffusion model.
        :param fullgraph: passed through to ``torch.compile``.
        :param dynamic: passed through to ``torch.compile``.
        :param backend: ``torch.compile`` backend name (default "inductor").
        :return: 1-tuple containing the (possibly compiled) model.
        """
        if object_patch is None:
            object_patch = DIFFUSION_MODEL
        compile_kwargs = {
            "fullgraph": fullgraph,
            "dynamic": dynamic,
            "backend": backend,
        }
        if isinstance(model, ModelPatcher):
            m = model.clone()
            m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
            return (m,)
        elif isinstance(model, torch.nn.Module):
            return (torch.compile(model=model, **compile_kwargs),)
        else:
            logging.warning("Encountered a model that cannot be compiled")
            return (model,)
|
|
|
|
|
|
class QuantizeModel(CustomNode):
    """Experimental node that int8-quantizes a model's diffusion module.

    Supports two strategies: "quanto" (Optimum Quanto, weights + activations)
    and "torchao" (int8 dynamic-activation / int8-weight). Both strategies
    quantize the underlying module in place, so every clone sharing that
    module is affected.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "model": ("MODEL", {}),
                "strategy": (["torchao", "quanto"], {"default": "torchao"})
            }
        }

    FUNCTION = "execute"
    CATEGORY = "_for_testing"
    EXPERIMENTAL = True

    RETURN_TYPES = ("MODEL",)

    def execute(self, model: ModelPatcher, strategy: str = "torchao"):
        """Quantize ``model``'s diffusion module with the chosen strategy.

        :param model: the ModelPatcher whose "diffusion_model" object is quantized.
        :param strategy: "torchao" or "quanto".
        :return: 1-tuple containing the cloned, patched model.
        :raises ValueError: on an unknown strategy.
        """
        logging.warning(f"Quantizing {model} this way quantizes it in place, making it unsuitable for cloning. All uses of this model will be quantized.")
        logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
        model = model.clone()
        unet = model.get_model_object("diffusion_model")
        # todo: quantize quantizes in place, which is not desired

        if strategy == "quanto":
            from optimum.quanto import quantize, qint8
            # Weightless LayerNorms cannot be quantized by quanto; exclude them.
            exclusion_list = [
                name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None
            ]
            quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
            _in_place_fixme = unet
        elif strategy == "torchao":
            from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
            quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
            _in_place_fixme = unet
        else:
            raise ValueError(f"unknown strategy {strategy}")

        model.add_object_patch("diffusion_model", _in_place_fixme)
        return (model,)
|
|
|
|
|
|
# Registry consumed by the node loader: maps node identifiers to the classes
# implementing them.
NODE_CLASS_MAPPINGS = {
    "TorchCompileModel": TorchCompileModel,
    "QuantizeModel": QuantizeModel,
}
|