From a8d8bff548df305abaa90d53af65e27b1daf27bc Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Tue, 29 Oct 2024 19:22:26 -0700
Subject: [PATCH] Improve support for torch compilation and sage attention

---
 comfy/ldm/flux/layers.py                  |  6 +-
 comfy/ldm/modules/attention.py            | 15 +++--
 comfy/model_patcher.py                    |  6 +-
 comfy/ops.py                              |  6 +-
 comfy_extras/nodes/nodes_torch_compile.py | 79 +++++++++++++++++------
 requirements-triton.txt                   |  5 ++
 requirements.txt                          |  3 +-
 setup.py                                  |  2 +
 8 files changed, 88 insertions(+), 34 deletions(-)
 create mode 100644 requirements-triton.txt

diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 11b241e2d..3ff47c935 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -173,8 +173,8 @@ class DoubleStreamBlock(nn.Module):
         img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
 
         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
 
         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -231,7 +231,7 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x = x + mod.gate * output
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index a25421750..88c71346e 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -28,6 +28,7 @@ from ... import ops
 ops = ops.disable_weight_init
 
 FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
+logger = logging.getLogger(__name__)
 
 
 def get_attn_precision(attn_precision):
@@ -324,12 +325,12 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 model_management.soft_empty_cache(True)
                 if cleared_cache == False:
                     cleared_cache = True
-                    logging.warning("out of memory error, emptying cache and trying again")
+                    logger.warning("out of memory error, emptying cache and trying again")
                     continue
                 steps *= 2
                 if steps > 64:
                     raise e
-                logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
+                logger.warning("out of memory error, increasing steps and trying again {}".format(steps))
             else:
                 raise e
 
@@ -432,20 +433,20 @@ def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_re
 optimized_attention = attention_basic
 
 if model_management.sage_attention_enabled():
-    logging.debug("Using sage attention")
+    logger.info("Using sage attention")
     optimized_attention = attention_sageattn
 elif model_management.xformers_enabled():
-    logging.debug("Using xformers cross attention")
+    logger.info("Using xformers cross attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
-    logging.debug("Using pytorch cross attention")
+    logger.info("Using pytorch cross attention")
     optimized_attention = attention_pytorch
 else:
     if args.use_split_cross_attention:
-        logging.debug("Using split optimization for cross attention")
+        logger.info("Using split optimization for cross attention")
         optimized_attention = attention_split
     else:
-        logging.debug("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        logger.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
 
 optimized_attention_masked = optimized_attention
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index a27a5f2a8..bdf818582 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -24,6 +24,7 @@ from typing import Optional
 
 import torch
 import torch.nn
+from humanize import naturalsize
 
 from . import model_management, lora
 from . import utils
@@ -600,10 +601,11 @@ class ModelPatcher(ModelManageable):
         return self.current_loaded_device()
 
     def __str__(self):
+        info_str = f"{self.model_dtype()} {self.model_device} {naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)}"
         if self.ckpt_name is not None:
-            return f""
+            return f""
         else:
-            return f""
+            return f""
 
     def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
         print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
diff --git a/comfy/ops.py b/comfy/ops.py
index 0855771b7..0ff1ed748 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -18,6 +18,7 @@ from typing import Optional
 
 import torch
+from torch import Tensor
 
 from . import model_management
 from .cli_args import args
 
@@ -92,7 +93,10 @@ class skip_init:
         pass
 
     class Embedding(SkipInit, torch.nn.Embedding):
-        pass
+        def forward(self, *args, **kwargs) -> Tensor:
+            if "out_dtype" in kwargs:
+                kwargs.pop("out_dtype")
+            return super().forward(*args, **kwargs)
 
     @classmethod
     def conv_nd(cls, dims, *args, **kwargs):
diff --git a/comfy_extras/nodes/nodes_torch_compile.py b/comfy_extras/nodes/nodes_torch_compile.py
index fce477be0..e6bcd545c 100644
--- a/comfy_extras/nodes/nodes_torch_compile.py
+++ b/comfy_extras/nodes/nodes_torch_compile.py
@@ -1,6 +1,10 @@
 import logging
+import os
+from pathlib import Path
+from typing import Union
 
 import torch
+import torch._inductor.codecache
 from torch.nn import LayerNorm
 
 from comfy import model_management
@@ -8,6 +12,35 @@ from comfy.model_patcher import ModelPatcher
 from comfy.nodes.package_typing import CustomNode, InputTypes
 
 DIFFUSION_MODEL = "diffusion_model"
+TORCH_COMPILE_BACKENDS = [
+    "inductor",
+    "torch_tensorrt",
+    "onnxrt",
+    "cudagraphs",
+    "openxla",
+    "tvm"
+]
+
+TORCH_COMPILE_MODES = [
+    "default",
+    "reduce-overhead",
+    "max-autotune",
+    "max-autotune-no-cudagraphs"
+]
+
+# fix torch bug on windows
+_old_write_atomic = torch._inductor.codecache.write_atomic
+
+
+def write_atomic(
+        path: str, content: Union[str, bytes], make_dirs: bool = False
+) -> None:
+    if Path(path).exists():
+        os.remove(path)
+    _old_write_atomic(path, content, make_dirs=make_dirs)
+
+
+torch._inductor.codecache.write_atomic = write_atomic
 
 
 class TorchCompileModel(CustomNode):
@@ -21,40 +54,46 @@ class TorchCompileModel(CustomNode):
                 "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
                 "fullgraph": ("BOOLEAN", {"default": False}),
                 "dynamic": ("BOOLEAN", {"default": False}),
-                "backend": ("STRING", {"default": "inductor"}),
+                "backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
+                "mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"})
             }
         }
 
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
-    INFERENCE_MODE = False
+    # INFERENCE_MODE = False
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
 
-    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor") -> tuple[ModelPatcher]:
+    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune") -> tuple[ModelPatcher]:
         if object_patch is None:
             object_patch = DIFFUSION_MODEL
         compile_kwargs = {
             "fullgraph": fullgraph,
             "dynamic": dynamic,
-            "backend": backend
+            "backend": backend,
+            "mode": mode,
         }
-        if backend == "torch_tensorrt":
-            compile_kwargs["options"] = {
-                # https://pytorch.org/TensorRT/dynamo/torch_compile.html
-                # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
-                "enabled_precisions": {torch.float, torch.half}
-            }
-        if isinstance(model, ModelPatcher):
-            m = model.clone()
-            m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
-            return (m,)
-        elif isinstance(model, torch.nn.Module):
-            return torch.compile(model=model, **compile_kwargs),
-        else:
-            logging.warning("Encountered a model that cannot be compiled")
-            return model,
+        try:
+            if backend == "torch_tensorrt":
+                compile_kwargs["options"] = {
+                    # https://pytorch.org/TensorRT/dynamo/torch_compile.html
+                    # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
+                    "enabled_precisions": {torch.float, torch.half}
+                }
+            if isinstance(model, ModelPatcher):
+                m = model.clone()
+                m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
+                return (m,)
+            elif isinstance(model, torch.nn.Module):
+                return torch.compile(model=model, **compile_kwargs),
+            else:
+                logging.warning("Encountered a model that cannot be compiled")
+                return model,
+        except OSError:
+            torch._inductor.utils.clear_inductor_caches()
+            raise
 
 
 _QUANTIZATION_STRATEGIES = [
@@ -77,7 +116,7 @@ class QuantizeModel(CustomNode):
     FUNCTION = "execute"
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
-    INFERENCE_MODE = False
+    # INFERENCE_MODE = False
 
     RETURN_TYPES = ("MODEL",)
diff --git a/requirements-triton.txt b/requirements-triton.txt
new file mode 100644
index 000000000..c95d88217
--- /dev/null
+++ b/requirements-triton.txt
@@ -0,0 +1,5 @@
+sageattention
+triton ;platform_system == 'Linux'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10'
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index d31dee04f..1d848c940 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -67,4 +67,5 @@ vtracer
 skia-python
 pebble>=5.0.7
 openai
-anthropic
\ No newline at end of file
+anthropic
+humanize
\ No newline at end of file
diff --git a/setup.py b/setup.py
index b2e1e3518..972765fd7 100644
--- a/setup.py
+++ b/setup.py
@@ -191,6 +191,7 @@ package_data = [
     '**/*'
 ]
 dev_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements-dev.txt")).readlines()
+triton_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements-triton.txt")).readlines()
 setup(
     name=package_name,
     description="An installable version of ComfyUI",
@@ -213,6 +214,7 @@ setup(
     extras_require={
         'withtorch': dependencies(install_torch_for_system=True),
         'withtorchnightly': dependencies(install_torch_for_system=True, force_nightly=True),
+        'withtriton': dependencies(install_torch_for_system=True) + triton_dependencies,
         'dev': dev_dependencies
     },
 )