Improve support for torch compilation and sage attention

commit a8d8bff548 (parent ea5078f7f2)
@@ -173,8 +173,8 @@ class DoubleStreamBlock(nn.Module):
         img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

         # calculate the txt bloks
-        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -231,7 +231,7 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x += mod.gate * output
+        x = x + mod.gate * output
         if x.dtype == torch.float16:
             x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
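Note: the point of rewriting the in-place updates (x += ...) as out-of-place ones (x = x + ...) is that mutating an input tensor in place can complicate torch.compile tracing and CUDA-graph capture, while the functional form stays easy to trace. A minimal sketch of the pattern on a toy module (hypothetical, not code from this repository):

    import torch

    class GatedResidual(torch.nn.Module):
        # Toy stand-in for the gated-residual pattern used in the flux blocks.
        def __init__(self, dim: int):
            super().__init__()
            self.proj = torch.nn.Linear(dim, dim)

        def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
            # Out-of-place addition keeps the op functional; "x += ..." would
            # mutate the caller's tensor in place.
            return x + gate * self.proj(x)

    compiled = torch.compile(GatedResidual(16))  # traces without mutating its inputs
    y = compiled(torch.randn(2, 16), torch.randn(2, 16))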
@@ -28,6 +28,7 @@ from ... import ops
 ops = ops.disable_weight_init

 FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
+logger = logging.getLogger(__name__)


 def get_attn_precision(attn_precision):
@@ -324,12 +325,12 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 model_management.soft_empty_cache(True)
                 if cleared_cache == False:
                     cleared_cache = True
-                    logging.warning("out of memory error, emptying cache and trying again")
+                    logger.warning("out of memory error, emptying cache and trying again")
                     continue
                 steps *= 2
                 if steps > 64:
                     raise e
-                logging.warning("out of memory error, increasing steps and trying again {}".format(steps))
+                logger.warning("out of memory error, increasing steps and trying again {}".format(steps))
             else:
                 raise e
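Note: attention_split recovers from CUDA out-of-memory errors by first emptying the cache and then doubling the number of slices, which roughly halves the peak memory needed per attention chunk. A self-contained sketch of that slicing idea (hypothetical helper, not this repository's implementation):

    import torch

    def sliced_attention(q, k, v, steps: int = 1):
        # Split the query into `steps` chunks so only one chunk's attention
        # matrix is materialized at a time; more steps => lower peak memory.
        outs = []
        for q_chunk in q.chunk(steps, dim=1):
            scores = q_chunk @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
            outs.append(scores.softmax(dim=-1) @ v)
        return torch.cat(outs, dim=1)

    q = k = v = torch.randn(1, 128, 64)
    # The result is identical regardless of how many slices are used.
    assert torch.allclose(sliced_attention(q, k, v, steps=1),
                          sliced_attention(q, k, v, steps=4), atol=1e-5)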
@@ -432,20 +433,20 @@ def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_re
 optimized_attention = attention_basic

 if model_management.sage_attention_enabled():
-    logging.debug("Using sage attention")
+    logger.info("Using sage attention")
     optimized_attention = attention_sageattn
 elif model_management.xformers_enabled():
-    logging.debug("Using xformers cross attention")
+    logger.info("Using xformers cross attention")
     optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
-    logging.debug("Using pytorch cross attention")
+    logger.info("Using pytorch cross attention")
     optimized_attention = attention_pytorch
 else:
     if args.use_split_cross_attention:
-        logging.debug("Using split optimization for cross attention")
+        logger.info("Using split optimization for cross attention")
         optimized_attention = attention_split
     else:
-        logging.debug("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+        logger.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
         optimized_attention = attention_sub_quad

 optimized_attention_masked = optimized_attention
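Note: besides adding sage attention to the dispatch, the commit switches this module from the root logging functions to a named module-level logger and raises the backend-selection messages from debug to info, so the chosen attention implementation is visible in normal logs. A minimal sketch of the same pattern with hypothetical backend functions:

    import logging

    logger = logging.getLogger(__name__)  # named logger, filterable per module

    def fast_attention(*args, **kwargs):      # hypothetical backend
        ...

    def fallback_attention(*args, **kwargs):  # hypothetical backend
        ...

    def pick_attention(sage_enabled: bool):
        # Select the implementation once, at import/startup time.
        if sage_enabled:
            logger.info("Using sage attention")
            return fast_attention
        logger.info("Using fallback attention")
        return fallback_attention

    logging.basicConfig(level=logging.INFO)
    optimized_attention = pick_attention(sage_enabled=False)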
@@ -24,6 +24,7 @@ from typing import Optional

 import torch
 import torch.nn
+from humanize import naturalsize

 from . import model_management, lora
 from . import utils
@@ -600,10 +601,11 @@ class ModelPatcher(ModelManageable):
         return self.current_loaded_device()

     def __str__(self):
+        info_str = f"{self.model_dtype()} {self.model_device} {naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)}"
         if self.ckpt_name is not None:
-            return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__})>"
+            return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__} {info_str})>"
         else:
-            return f"<ModelPatcher for {self.model.__class__.__name__}>"
+            return f"<ModelPatcher for {self.model.__class__.__name__} ({info_str})>"

     def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
         print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
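Note: humanize.naturalsize renders a byte count as a human-readable string; with binary=True it uses 1024-based units (KiB, MiB, GiB), which is how the new __str__ reports loaded-weight memory. For example:

    from humanize import naturalsize

    print(naturalsize(6871947673, binary=True))  # "6.4 GiB"
    print(naturalsize(6871947673))               # "6.9 GB" (decimal units)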
@@ -18,6 +18,7 @@
 from typing import Optional

 import torch
+from torch import Tensor

 from . import model_management
 from .cli_args import args
@@ -92,7 +93,10 @@ class skip_init:
     pass

 class Embedding(SkipInit, torch.nn.Embedding):
-    pass
+    def forward(self, *args, **kwargs) -> Tensor:
+        if "out_dtype" in kwargs:
+            kwargs.pop("out_dtype")
+        return super().forward(*args, **kwargs)

     @classmethod
     def conv_nd(cls, dims, *args, **kwargs):
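Note: callers in this codebase apparently pass an out_dtype keyword through these ops wrappers, but torch.nn.Embedding.forward does not accept it, so the new override drops the kwarg before delegating to the stock implementation. A minimal sketch of the same "strip unsupported kwargs" pattern as a standalone class (hypothetical, simplified from the diff):

    import torch
    from torch import Tensor

    class Embedding(torch.nn.Embedding):
        def forward(self, *args, **kwargs) -> Tensor:
            # Drop a kwarg the parent class does not understand.
            kwargs.pop("out_dtype", None)
            return super().forward(*args, **kwargs)

    emb = Embedding(10, 4)
    tokens = torch.tensor([1, 2, 3])
    out = emb(tokens, out_dtype=torch.float16)  # extra kwarg is ignored rather than raising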
@@ -1,6 +1,10 @@
 import logging
+import os
+from pathlib import Path
+from typing import Union

 import torch
+import torch._inductor.codecache
 from torch.nn import LayerNorm

 from comfy import model_management
@@ -8,6 +12,35 @@ from comfy.model_patcher import ModelPatcher
 from comfy.nodes.package_typing import CustomNode, InputTypes

 DIFFUSION_MODEL = "diffusion_model"
+TORCH_COMPILE_BACKENDS = [
+    "inductor",
+    "torch_tensorrt",
+    "onnxrt",
+    "cudagraphs",
+    "openxla",
+    "tvm"
+]
+
+TORCH_COMPILE_MODES = [
+    "default",
+    "reduce-overhead",
+    "max-autotune",
+    "max-autotune-no-cudagraphs"
+]
+
+# fix torch bug on windows
+_old_write_atomic = torch._inductor.codecache.write_atomic
+
+
+def write_atomic(
+        path: str, content: Union[str, bytes], make_dirs: bool = False
+) -> None:
+    if Path(path).exists():
+        os.remove(path)
+    _old_write_atomic(path, content, make_dirs=make_dirs)
+
+
+torch._inductor.codecache.write_atomic = write_atomic


 class TorchCompileModel(CustomNode):
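Note: the two new lists enumerate the backend and mode strings that torch.compile accepts, and the node below forwards them straight through. A minimal standalone sketch of such a call (toy module; assumes the inductor backend is available):

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())

    compiled = torch.compile(
        model,
        fullgraph=False,      # allow graph breaks
        dynamic=False,        # specialize on static shapes
        backend="inductor",   # one of TORCH_COMPILE_BACKENDS
        mode="max-autotune",  # one of TORCH_COMPILE_MODES
    )
    out = compiled(torch.randn(2, 8))  # first call triggers compilation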
@@ -21,40 +54,46 @@ class TorchCompileModel(CustomNode):
                 "object_patch": ("STRING", {"default": DIFFUSION_MODEL}),
                 "fullgraph": ("BOOLEAN", {"default": False}),
                 "dynamic": ("BOOLEAN", {"default": False}),
-                "backend": ("STRING", {"default": "inductor"}),
+                "backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
+                "mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"})
             }
         }

     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
-    INFERENCE_MODE = False
+    # INFERENCE_MODE = False

     CATEGORY = "_for_testing"
     EXPERIMENTAL = True

-    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor") -> tuple[ModelPatcher]:
+    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune") -> tuple[ModelPatcher]:
         if object_patch is None:
             object_patch = DIFFUSION_MODEL
         compile_kwargs = {
             "fullgraph": fullgraph,
             "dynamic": dynamic,
-            "backend": backend
+            "backend": backend,
+            "mode": mode,
         }
-        if backend == "torch_tensorrt":
-            compile_kwargs["options"] = {
-                # https://pytorch.org/TensorRT/dynamo/torch_compile.html
-                # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
-                "enabled_precisions": {torch.float, torch.half}
-            }
-        if isinstance(model, ModelPatcher):
-            m = model.clone()
-            m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
-            return (m,)
-        elif isinstance(model, torch.nn.Module):
-            return torch.compile(model=model, **compile_kwargs),
-        else:
-            logging.warning("Encountered a model that cannot be compiled")
-            return model,
+        try:
+            if backend == "torch_tensorrt":
+                compile_kwargs["options"] = {
+                    # https://pytorch.org/TensorRT/dynamo/torch_compile.html
+                    # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
+                    "enabled_precisions": {torch.float, torch.half}
+                }
+            if isinstance(model, ModelPatcher):
+                m = model.clone()
+                m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
+                return (m,)
+            elif isinstance(model, torch.nn.Module):
+                return torch.compile(model=model, **compile_kwargs),
+            else:
+                logging.warning("Encountered a model that cannot be compiled")
+                return model,
+        except OSError:
+            torch._inductor.utils.clear_inductor_caches()
+            raise


 _QUANTIZATION_STRATEGIES = [
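Note: the patch method now forwards the chosen mode to torch.compile and wraps compilation in try/except OSError so a broken inductor cache is cleared before the error propagates. A hypothetical usage sketch, calling the node directly on a bare nn.Module (the node also accepts a ModelPatcher; this avoids needing a loaded checkpoint, and assumes TorchCompileModel is importable from this file):

    import torch

    toy = torch.nn.Linear(8, 8)
    (compiled_toy,) = TorchCompileModel().patch(
        toy,
        object_patch=None,   # falls back to DIFFUSION_MODEL, unused for plain modules
        fullgraph=False,
        dynamic=False,
        backend="inductor",
        mode="max-autotune",
    )
    out = compiled_toy(torch.randn(2, 8))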
@@ -77,7 +116,7 @@ class QuantizeModel(CustomNode):
     FUNCTION = "execute"
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
-    INFERENCE_MODE = False
+    # INFERENCE_MODE = False

     RETURN_TYPES = ("MODEL",)
requirements-triton.txt (new file)
@@ -0,0 +1,5 @@
+sageattention
+triton ;platform_system == 'Linux'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp312-cp312-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.12'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp311-cp311-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.11'
+triton @ https://github.com/woct0rdho/triton-windows/releases/download/v3.1.0-windows.post5/triton-3.1.0-cp310-cp310-win_amd64.whl ;platform_system == 'Windows' and python_version == '3.10'
@@ -68,3 +68,4 @@ skia-python
 pebble>=5.0.7
 openai
 anthropic
+humanize
setup.py
@@ -191,6 +191,7 @@ package_data = [
     '**/*'
 ]
 dev_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements-dev.txt")).readlines()
+triton_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements-triton.txt")).readlines()
 setup(
     name=package_name,
     description="An installable version of ComfyUI",
@@ -213,6 +214,7 @@ setup(
     extras_require={
         'withtorch': dependencies(install_torch_for_system=True),
         'withtorchnightly': dependencies(install_torch_for_system=True, force_nightly=True),
+        'withtriton': dependencies(install_torch_for_system=True) + triton_dependencies,
         'dev': dev_dependencies
     },
 )
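Note: with the new extra in place, the triton and sageattention requirements can presumably be pulled in at install time alongside torch, e.g. with pip install -e ".[withtriton]" from a source checkout (hypothetical invocation; the exact command depends on how the package is installed).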