Improve support for torch compilation and sage attention

doctorpangloss 2024-10-29 19:22:26 -07:00
parent ea5078f7f2
commit a8d8bff548
8 changed files with 88 additions and 34 deletions

View File

@@ -173,8 +173,8 @@ class DoubleStreamBlock(nn.Module):
         img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
 
         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
 
         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -231,7 +231,7 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x = x + mod.gate * output
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
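
Note on the hunks above: rewriting txt += ... and x += ... as out-of-place additions avoids mutating tensors in place, which is generally friendlier to torch.compile graph capture. A minimal sketch of the pattern (the function names are illustrative, not part of this commit):

import torch


def update_in_place(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    x += delta  # mutates the caller's tensor in place
    return x


def update_out_of_place(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    x = x + delta  # rebinds x to a freshly allocated tensor; no in-place mutation
    return x


# the out-of-place form compiles and runs cleanly
compiled = torch.compile(update_out_of_place)
y = compiled(torch.randn(4), torch.randn(4))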

View File

@@ -28,6 +28,7 @@ from ... import ops
 ops = ops.disable_weight_init
 
 FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
+logger = logging.getLogger(__name__)
 
 def get_attn_precision(attn_precision):
@@ -324,12 +325,12 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 model_management.soft_empty_cache(True)
                 if cleared_cache == False:
                     cleared_cache = True
-                    logging.warning("out of memory error, emptying cache and trying again")
+                    logger.warning("out of memory error, emptying cache and trying again")
                     continue
                 steps *= 2
                 if steps > 64:
                     raise e
-                logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
+                logger.warning("out of memory error, increasing steps and trying again {}".format(steps))
             else:
                 raise e
@@ -432,20 +433,20 @@ def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_re
 optimized_attention = attention_basic
 
 if model_management.sage_attention_enabled():
-    logging.debug("Using sage attention")
+    logger.info("Using sage attention")
     optimized_attention = attention_sageattn
 elif model_management.xformers_enabled():
-    logging.debug("Using xformers cross attention")
+    logger.info("Using xformers cross attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
-    logging.debug("Using pytorch cross attention")
+    logger.info("Using pytorch cross attention")
     optimized_attention = attention_pytorch
 else:
     if args.use_split_cross_attention:
-        logging.debug("Using split optimization for cross attention")
+        logger.info("Using split optimization for cross attention")
         optimized_attention = attention_split
     else:
-        logging.debug("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        logger.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad
 
 optimized_attention_masked = optimized_attention
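
Note on the sage attention branch above: when model_management.sage_attention_enabled() returns True, attention_sageattn becomes the optimized attention function. A hedged sketch of what such a wrapper can look like, assuming the sageattention package exposes sageattn(q, k, v, is_causal=...) over (batch, heads, seq_len, head_dim) tensors; the wrapper name and layout are assumptions, not this repository's implementation:

import torch
from sageattention import sageattn  # assumed entry point of the sageattention package


def attention_sage_sketch(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k, v assumed to be (batch, heads, seq_len, head_dim), e.g. fp16/bf16 on CUDA
    return sageattn(q, k, v, is_causal=False)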

View File

@@ -24,6 +24,7 @@ from typing import Optional
 import torch
 import torch.nn
+from humanize import naturalsize
 
 from . import model_management, lora
 from . import utils
@@ -600,10 +601,11 @@ class ModelPatcher(ModelManageable):
         return self.current_loaded_device()
 
     def __str__(self):
+        info_str = f"{self.model_dtype()} {self.model_device} {naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)}"
         if self.ckpt_name is not None:
-            return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__})>"
+            return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__} {info_str})>"
         else:
-            return f"<ModelPatcher for {self.model.__class__.__name__}>"
+            return f"<ModelPatcher for {self.model.__class__.__name__} ({info_str})>"
 
     def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
         print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")

View File

@@ -18,6 +18,7 @@
 from typing import Optional
 
 import torch
+from torch import Tensor
 
 from . import model_management
 from .cli_args import args
@@ -92,7 +93,10 @@ class skip_init:
         pass
 
     class Embedding(SkipInit, torch.nn.Embedding):
-        pass
+        def forward(self, *args, **kwargs) -> Tensor:
+            if "out_dtype" in kwargs:
+                kwargs.pop("out_dtype")
+            return super().forward(*args, **kwargs)
 
     @classmethod
     def conv_nd(cls, dims, *args, **kwargs):
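
The Embedding override above drops the out_dtype keyword that some ComfyUI call sites pass, since torch.nn.Embedding.forward does not accept it. A standalone sketch of the same pattern, without the SkipInit mixin (the class name here is illustrative):

import torch
from torch import Tensor


class PermissiveEmbedding(torch.nn.Embedding):
    def forward(self, *args, **kwargs) -> Tensor:
        kwargs.pop("out_dtype", None)  # silently discard the unsupported kwarg
        return super().forward(*args, **kwargs)


emb = PermissiveEmbedding(10, 4)
out = emb(torch.tensor([1, 2, 3]), out_dtype=torch.float16)  # extra kwarg is ignored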

View File

@@ -1,6 +1,10 @@
 import logging
+import os
+from pathlib import Path
+from typing import Union
 
 import torch
+import torch._inductor.codecache
 from torch.nn import LayerNorm
 
 from comfy import model_management
@@ -8,6 +12,35 @@ from comfy.model_patcher import ModelPatcher
 from comfy.nodes.package_typing import CustomNode, InputTypes
 
 DIFFUSION_MODEL = "diffusion_model"
+TORCH_COMPILE_BACKENDS = [
+    "inductor",
+    "torch_tensorrt",
+    "onnxrt",
+    "cudagraphs",
+    "openxla",
+    "tvm"
+]
+
+TORCH_COMPILE_MODES = [
+    "default",
+    "reduce-overhead",
+    "max-autotune",
+    "max-autotune-no-cudagraphs"
+]
+
+
+# fix torch bug on windows
+_old_write_atomic = torch._inductor.codecache.write_atomic
+
+
+def write_atomic(
+        path: str, content: Union[str, bytes], make_dirs: bool = False
+) -> None:
+    if Path(path).exists():
+        os.remove(path)
+    _old_write_atomic(path, content, make_dirs=make_dirs)
+
+
+torch._inductor.codecache.write_atomic = write_atomic
 
 
 class TorchCompileModel(CustomNode):
@@ -21,40 +54,46 @@ class TorchCompileModel(CustomNode):
                 "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
                 "fullgraph": ("BOOLEAN", {"default": False}),
                 "dynamic": ("BOOLEAN", {"default": False}),
-                "backend": ("STRING", {"default": "inductor"}),
+                "backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
+                "mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"})
             }
         }
 
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
-    INFERENCE_MODE = False
+    # INFERENCE_MODE = False
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
 
-    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor") -> tuple[ModelPatcher]:
+    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune") -> tuple[ModelPatcher]:
         if object_patch is None:
             object_patch = DIFFUSION_MODEL
         compile_kwargs = {
             "fullgraph": fullgraph,
             "dynamic": dynamic,
-            "backend": backend
+            "backend": backend,
+            "mode": mode,
         }
-        if backend == "torch_tensorrt":
-            compile_kwargs["options"] = {
-                # https://pytorch.org/TensorRT/dynamo/torch_compile.html
-                # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
-                "enabled_precisions": {torch.float, torch.half}
-            }
-        if isinstance(model, ModelPatcher):
-            m = model.clone()
-            m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
-            return (m,)
-        elif isinstance(model, torch.nn.Module):
-            return torch.compile(model=model, **compile_kwargs),
-        else:
-            logging.warning("Encountered a model that cannot be compiled")
-            return model,
+        try:
+            if backend == "torch_tensorrt":
+                compile_kwargs["options"] = {
+                    # https://pytorch.org/TensorRT/dynamo/torch_compile.html
+                    # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
+                    "enabled_precisions": {torch.float, torch.half}
+                }
+            if isinstance(model, ModelPatcher):
+                m = model.clone()
+                m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
+                return (m,)
+            elif isinstance(model, torch.nn.Module):
+                return torch.compile(model=model, **compile_kwargs),
+            else:
+                logging.warning("Encountered a model that cannot be compiled")
+                return model,
+        except OSError:
+            torch._inductor.utils.clear_inductor_caches()
+            raise
 
 
 _QUANTIZATION_STRATEGIES = [
@@ -77,7 +116,7 @@ class QuantizeModel(CustomNode):
     FUNCTION = "execute"
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
-    INFERENCE_MODE = False
+    # INFERENCE_MODE = False
     RETURN_TYPES = ("MODEL",)
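
At its core, the patched node boils down to a torch.compile call with the selected backend and mode applied to the object patched onto the ModelPatcher. Outside of ComfyUI, the equivalent call looks roughly like this (toy module, not the node itself):

import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())
compiled = torch.compile(
    model,
    fullgraph=False,
    dynamic=False,
    backend="inductor",   # one of TORCH_COMPILE_BACKENDS above
    mode="max-autotune",  # one of TORCH_COMPILE_MODES above
)
out = compiled(torch.randn(2, 16))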

requirements-triton.txt Normal file
View File

@@ -0,0 +1,5 @@
+sageattention
+triton ;platform_system == 'Linux'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10'

View File

@@ -68,3 +68,4 @@ skia-python
 pebble>=5.0.7
 openai
 anthropic
+humanize

View File

@@ -191,6 +191,7 @@ package_data = [
     '**/*'
 ]
 dev_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements-dev.txt")).readlines()
+triton_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements-triton.txt")).readlines()
 setup(
     name=package_name,
     description="An installable version of ComfyUI",
@@ -213,6 +214,7 @@ setup(
     extras_require={
         'withtorch': dependencies(install_torch_for_system=True),
         'withtorchnightly': dependencies(install_torch_for_system=True, force_nightly=True),
+        'withtriton': dependencies(install_torch_for_system=True) + triton_dependencies,
         'dev': dev_dependencies
     },
 )
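
With the new extra in place, an install from a source checkout could pull the triton/sageattention requirements alongside the torch dependencies with something like pip install -e ".[withtriton]"; the exact install invocation is an assumption, only the extra name and the requirements-triton.txt contents come from this diff.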