Memory management and compilation improvements

- Experimental support for sage attention on Linux
- Diffusers loader now supports model indices
- Transformers model management now aligns with updates to ComfyUI
- Flux layers correctly use unbind
- Float8 support for model loading in more places
- Experimental quantization approaches from Quanto and torchao
- Model upscaling interacts with memory management better

This update also disables ROCm testing in CI because it isn't reliable enough
on consumer hardware; the RX 7600 in particular is not well supported by ROCm.
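
As a rough illustration of how the optional attention backends are gated (a minimal sketch, not this repository's actual module; the flag names simply mirror the new --disable-flash-attn and --disable-sage-attention options, and the precedence shown is only a plausible reading of the changes below):

import importlib.util

# Hypothetical stand-ins for the new --disable-flash-attn / --disable-sage-attention options.
disable_flash_attn = False
disable_sage_attention = False

# Both packages are optional: enable a backend only when it is installed and not disabled.
FLASH_ATTENTION_ENABLED = not disable_flash_attn and importlib.util.find_spec("flash_attn") is not None
SAGE_ATTENTION_ENABLED = not disable_sage_attention and importlib.util.find_spec("sageattention") is not None


def choose_attention_backend(xformers_available: bool = False) -> str:
    # The diff only shows that sage attention defers to xformers; the rest of this
    # ordering is illustrative, with PyTorch SDPA as the universal fallback.
    if SAGE_ATTENTION_ENABLED and not xformers_available:
        return "sageattention"
    if FLASH_ATTENTION_ENABLED:
        return "flash_attn"
    return "pytorch_sdpa"


print(choose_attention_backend())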
This commit is contained in:
doctorpangloss 2024-10-09 09:11:22 -07:00
parent 0a25b67ff8
commit bbe2ed330c
16 changed files with 319 additions and 142 deletions

View File

@@ -52,8 +52,6 @@ jobs:
       runner:
         - labels: [self-hosted, Linux, X64, cpu]
           container: "ubuntu:22.04"
-        - labels: [self-hosted, Linux, X64, rocm-7600-8gb]
-          container: "rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0"
         - labels: [self-hosted, Linux, X64, cuda-3060-12gb]
           container: "nvcr.io/nvidia/pytorch:24.03-py3"
     steps:

View File

@@ -99,6 +99,8 @@ def _create_parser() -> EnhancedConfigArgParser:
                         help="Use the new pytorch 2.0 cross attention function.")
     parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
+    parser.add_argument("--disable-flash-attn", action="store_true", help="Disable Flash Attention")
+    parser.add_argument("--disable-sage-attention", action="store_true", help="Disable Sage Attention")
     upcast = parser.add_mutually_exclusive_group()
     upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")

View File

@@ -75,6 +75,8 @@ class Configuration(dict):
         use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
         use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
         disable_xformers (bool): Disable xformers.
+        disable_flash_attn (bool): Disable flash_attn package attention.
+        disable_sage_attention (bool): Disable sage attention package attention.
         gpu_only (bool): Run everything on the GPU.
         highvram (bool): Keep models in GPU memory.
         normalvram (bool): Default VRAM usage setting.
@@ -157,6 +159,8 @@ class Configuration(dict):
         self.use_quad_cross_attention: bool = False
         self.use_pytorch_cross_attention: bool = False
         self.disable_xformers: bool = False
+        self.disable_flash_attn: bool = False
+        self.disable_sage_attention: bool = False
         self.gpu_only: bool = False
         self.highvram: bool = False
         self.normalvram: bool = False

View File

@@ -13,9 +13,17 @@ def first_file(path, filenames) -> str | None:
     return None


-def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
-    diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
-    unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
+def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None, model_options=None):
+    if model_options is None:
+        model_options = {}
+    diffusion_model_names = [
+        "diffusion_pytorch_model.fp16.safetensors",
+        "diffusion_pytorch_model.safetensors",
+        "diffusion_pytorch_model.fp16.bin",
+        "diffusion_pytorch_model.bin",
+        "diffusion_pytorch_model.safetensors.index.json"
+    ]
+    unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names) or first_file(os.path.join(model_path, "transformer"), diffusion_model_names)
     vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)

     text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
@@ -28,7 +36,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
     unet = None
     if unet_path is not None:
-        unet = sd.load_diffusion_model(unet_path)
+        unet = sd.load_diffusion_model(unet_path, model_options=model_options)

     clip = None
     textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"])

View File

@@ -79,13 +79,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
         # if we have flash-attn installed, try to use it
         try:
-            import flash_attn
-            attn_override_kwargs = {
-                "attn_implementation": "flash_attention_2",
-                **kwargs_to_try[0]
-            }
-            kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
-            logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
+            if model_management.flash_attn_enabled():
+                attn_override_kwargs = {
+                    "attn_implementation": "flash_attention_2",
+                    **kwargs_to_try[0]
+                }
+                kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
+                logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
         except ImportError:
             pass
         for i, props in enumerate(kwargs_to_try):
@@ -303,16 +303,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
     def model_dtype(self) -> torch.dtype:
         return self.model.dtype

-    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights=False) -> torch.nn.Module:
-        warnings.warn("Transformers models do not currently support adapters like LoRAs")
-        return self.model.to(device=device_to)
-
-    def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
-        return self.model.to(device=device_to)
-
-    def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
-        return self.model.to(device=offload_device)
+    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
+        return self.model.to(device=device_to)
+
+    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        return self.model.to(device=device_to)

     def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
         model = copy.copy(self)
         model._processor = processor

View File

@@ -149,14 +149,16 @@ class DoubleStreamBlock(nn.Module):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        img_qkv = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        img_q, img_k, img_v = torch.unbind(img_qkv, dim=0)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        txt_qkv = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        txt_q, txt_k, txt_v = torch.unbind(txt_qkv, dim=0)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         # run actual attention
@@ -221,7 +223,8 @@ class SingleStreamBlock(nn.Module):
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

-        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        qkv = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        q, k, v = torch.unbind(qkv, dim=0)
         q, k = self.norm(q, k, v)

         # compute attention
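
For reference, the qkv unpacking pattern these Flux blocks switch to, as a standalone toy example (the sizes here are arbitrary, not the model's real dimensions):

import torch

batch, seq, heads, dim_head = 2, 16, 4, 8

# Output of a fused qkv projection, i.e. one Linear producing (batch, seq, 3 * heads * dim_head).
qkv = torch.randn(batch, seq, 3 * heads * dim_head)

# Reshape to (3, batch, heads, seq, dim_head), then split along the leading axis with unbind.
qkv = qkv.view(batch, seq, 3, heads, dim_head).permute(2, 0, 3, 1, 4)
q, k, v = torch.unbind(qkv, dim=0)  # each is (batch, heads, seq, dim_head)

assert q.shape == (batch, heads, seq, dim_head)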

View File

@ -1,10 +1,12 @@
import logging
import math import math
from functools import wraps
from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat from einops import rearrange, repeat
from typing import Optional from torch import nn, einsum
import logging
from .diffusionmodules.util import AlphaBlender, timestep_embedding from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention from .sub_quadratic_attention import efficient_dot_product_attention
@ -12,14 +14,22 @@ from ... import model_management
if model_management.xformers_enabled(): if model_management.xformers_enabled():
import xformers # pylint: disable=import-error import xformers # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error import xformers.ops # pylint: disable=import-error
if model_management.sage_attention_enabled():
from sageattention import sageattn
if model_management.flash_attn_enabled():
from flash_attn import flash_attn_func
from ...cli_args import args from ...cli_args import args
from ... import ops from ... import ops
ops = ops.disable_weight_init ops = ops.disable_weight_init
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype() FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
def get_attn_precision(attn_precision): def get_attn_precision(attn_precision):
if args.dont_upcast_attention: if args.dont_upcast_attention:
return None return None
@ -27,12 +37,13 @@ def get_attn_precision(attn_precision):
return FORCE_UPCAST_ATTENTION_DTYPE return FORCE_UPCAST_ATTENTION_DTYPE
return attn_precision return attn_precision
def exists(val): def exists(val):
return val is not None return val is not None
def uniq(arr): def uniq(arr):
return{el: True for el in arr}.keys() return {el: True for el in arr}.keys()
def default(val, d): def default(val, d):
@ -82,9 +93,11 @@ class FeedForward(nn.Module):
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
def Normalize(in_channels, dtype=None, device=None): def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision) attn_precision = get_attn_precision(attn_precision)
@ -98,7 +111,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
h = heads h = heads
if skip_reshape: if skip_reshape:
q, k, v = map( q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head), lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v), (q, k, v),
) )
@ -122,7 +135,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
if exists(mask): if exists(mask):
if mask.dtype == torch.bool: if mask.dtype == torch.bool:
mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention mask = rearrange(mask, 'b ... -> b (...)') # TODO: check if this bool part matches pytorch attention
max_neg_value = -torch.finfo(sim.dtype).max max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h) mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value) sim.masked_fill_(~mask, max_neg_value)
@ -167,13 +180,12 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head) value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype dtype = query.dtype
upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32 upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
if upcast_attention: if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8 bytes_per_token = torch.finfo(torch.float32).bits // 8
else: else:
bytes_per_token = torch.finfo(query.dtype).bits//8 bytes_per_token = torch.finfo(query.dtype).bits // 8
batch_x_heads, q_tokens, _ = query.shape batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key.shape _, _, k_tokens = key.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
@ -215,9 +227,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.to(dtype) hidden_states = hidden_states.to(dtype)
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)
return hidden_states return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision) attn_precision = get_attn_precision(attn_precision)
@ -231,7 +244,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
h = heads h = heads
if skip_reshape: if skip_reshape:
q, k, v = map( q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head), lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v), (q, k, v),
) )
@ -262,16 +275,15 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
mem_required = tensor_size * modifier mem_required = tensor_size * modifier
steps = 1 steps = 1
if mem_required > mem_free_total: if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64: if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
if mask is not None: if mask is not None:
if len(mask.shape) == 2: if len(mask.shape) == 2:
@ -289,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
for i in range(0, q.shape[1], slice_size): for i in range(0, q.shape[1], slice_size):
end = i + slice_size end = i + slice_size
if upcast: if upcast:
with torch.autocast(enabled=False, device_type = 'cuda'): with torch.autocast(enabled=False, device_type='cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else: else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
@ -331,11 +343,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
) )
return r1 return r1
BROKEN_XFORMERS = False
if model_management.xformers_enabled():
x_vers = xformers.__version__
# XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape: if skip_reshape:
@ -346,10 +353,6 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = False disabled_xformers = False
if BROKEN_XFORMERS:
if b * heads > 65535:
disabled_xformers = True
if not disabled_xformers: if not disabled_xformers:
if torch.jit.is_tracing() or torch.jit.is_scripting(): if torch.jit.is_tracing() or torch.jit.is_scripting():
disabled_xformers = True disabled_xformers = True
@ -358,7 +361,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape: if skip_reshape:
q, k, v = map( q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head), lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v), (q, k, v),
) )
@@ -390,22 +393,36 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
     return out


-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
-    if skip_reshape:
-        b, _, _, dim_head = q.shape
-    else:
-        b, _, dim_head = q.shape
-        dim_head //= heads
-        q, k, v = map(
-            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
-            (q, k, v),
-        )
-
-    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
-    out = (
-        out.transpose(1, 2).reshape(b, -1, heads * dim_head)
-    )
-    return out
+def pytorch_style_decl(func):
+    @wraps(func)
+    def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+        if skip_reshape:
+            b, _, _, dim_head = q.shape
+        else:
+            b, _, dim_head = q.shape
+            dim_head //= heads
+            q, k, v = map(
+                lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+                (q, k, v),
+            )
+
+        out = func(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
+        out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+        return out
+
+    return wrapper
+
+
+@pytorch_style_decl
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+
+
+@pytorch_style_decl
+def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+
+
+@pytorch_style_decl
+def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
+    return flash_attn_func(q, k, v)
optimized_attention = attention_basic optimized_attention = attention_basic
@ -426,10 +443,11 @@ else:
optimized_attention_masked = optimized_attention optimized_attention_masked = optimized_attention
def optimized_attention_for_device(device, mask=False, small_input=False): def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input: if small_input:
if model_management.pytorch_attention_enabled(): if model_management.pytorch_attention_enabled():
return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases return attention_pytorch # TODO: need to confirm but this is probably slightly faster for small inputs in all cases
else: else:
return attention_basic return attention_basic
@ -493,7 +511,7 @@ class BasicTransformerBlock(nn.Module):
self.disable_self_attn = disable_self_attn self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention: if disable_temporal_crossattention:
@ -507,7 +525,7 @@ class BasicTransformerBlock(nn.Module):
context_dim_attn2 = context_dim context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@ -641,6 +659,7 @@ class SpatialTransformer(nn.Module):
Finally, reshape to image Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs NEW: use_linear for more efficiency instead of the 1x1 convs
""" """
def __init__(self, in_channels, n_heads, d_head, def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None, depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False, disable_self_attn=False, use_linear=False,
@ -653,23 +672,23 @@ class SpatialTransformer(nn.Module):
self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
if not use_linear: if not use_linear:
self.proj_in = operations.Conv2d(in_channels, self.proj_in = operations.Conv2d(in_channels,
inner_dim, inner_dim,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, dtype=dtype, device=device) padding=0, dtype=dtype, device=device)
else: else:
self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList( self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
for d in range(depth)] for d in range(depth)]
) )
if not use_linear: if not use_linear:
self.proj_out = operations.Conv2d(inner_dim,in_channels, self.proj_out = operations.Conv2d(inner_dim, in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, dtype=dtype, device=device) padding=0, dtype=dtype, device=device)
else: else:
self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.use_linear = use_linear self.use_linear = use_linear
@ -699,27 +718,27 @@ class SpatialTransformer(nn.Module):
class SpatialVideoTransformer(SpatialTransformer): class SpatialVideoTransformer(SpatialTransformer):
def __init__( def __init__(
self, self,
in_channels, in_channels,
n_heads, n_heads,
d_head, d_head,
depth=1, depth=1,
dropout=0.0, dropout=0.0,
use_linear=False, use_linear=False,
context_dim=None, context_dim=None,
use_spatial_context=False, use_spatial_context=False,
timesteps=None, timesteps=None,
merge_strategy: str = "fixed", merge_strategy: str = "fixed",
merge_factor: float = 0.5, merge_factor: float = 0.5,
time_context_dim=None, time_context_dim=None,
ff_in=False, ff_in=False,
checkpoint=False, checkpoint=False,
time_depth=1, time_depth=1,
disable_self_attn=False, disable_self_attn=False,
disable_temporal_crossattention=False, disable_temporal_crossattention=False,
max_time_embed_period: int = 10000, max_time_embed_period: int = 10000,
attn_precision=None, attn_precision=None,
dtype=None, device=None, operations=ops dtype=None, device=None, operations=ops
): ):
super().__init__( super().__init__(
in_channels, in_channels,
@ -785,13 +804,13 @@ class SpatialVideoTransformer(SpatialTransformer):
) )
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
context: Optional[torch.Tensor] = None, context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None, time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None, timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={} transformer_options={}
) -> torch.Tensor: ) -> torch.Tensor:
_, _, h, w = x.shape _, _, h, w = x.shape
x_in = x x_in = x
@ -801,7 +820,7 @@ class SpatialVideoTransformer(SpatialTransformer):
if self.use_spatial_context: if self.use_spatial_context:
assert ( assert (
context.ndim == 3 context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}" ), f"n dims of spatial context should be 3 but are {context.ndim}"
if time_context is None: if time_context is None:
@ -830,7 +849,7 @@ class SpatialVideoTransformer(SpatialTransformer):
emb = emb[:, None, :] emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate( for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack) zip(self.transformer_blocks, self.time_stack)
): ):
transformer_options["block_index"] = it_ transformer_options["block_index"] = it_
x = block( x = block(
@ -844,7 +863,7 @@ class SpatialVideoTransformer(SpatialTransformer):
B, S, C = x_mix.shape B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options x_mix = mix_block(x_mix, context=time_context) # TODO: transformer_options
x_mix = rearrange( x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
) )
@ -858,5 +877,3 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x) x = self.proj_out(x)
out = x + x_in out = x + x_in
return out return out
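
The pytorch_style_decl refactor earlier in this file standardizes a (batch, tokens, heads * dim_head) in/out layout around backends that operate on (batch, heads, tokens, dim_head). A self-contained sketch of that same reshaping contract, using PyTorch SDPA as the inner call (toy sizes, independent of this module's imports):

import torch
import torch.nn.functional as F

def sdpa_flat(q, k, v, heads):
    # q, k, v arrive as (batch, tokens, heads * dim_head), the layout the wrapper accepts.
    b, _, inner = q.shape
    dim_head = inner // heads
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)  # (batch, heads, tokens, dim_head)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

x = torch.randn(2, 77, 8 * 64)
print(sdpa_flat(x, x, x, heads=8).shape)  # torch.Size([2, 77, 512])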

View File

@@ -108,7 +108,6 @@ class BaseModel(torch.nn.Module):
             operations = model_config.custom_operations
         self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
         if model_management.force_channels_last():
-            # todo: ???
             self.diffusion_model.to(memory_format=torch.channels_last)
             logging.debug("using channels last mode for diffusion model")
         logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))

View File

@@ -535,10 +535,6 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
             for user_dir in Path(local_dir_root).iterdir():
                 for model_dir in user_dir.iterdir():
-                    try:
-                        _hf_fs.resolve_path(str(user_dir / model_dir))
-                    except Exception as exc_info:
-                        logging.debug(f"HuggingFaceFS did not think this was a valid repo: {user_dir.name}/{model_dir.name} with error {exc_info}", exc_info)
                     existing_local_dir_repos.add(f"{user_dir.name}/{model_dir.name}")

     known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)

View File

@ -23,7 +23,7 @@ import sys
import warnings import warnings
from enum import Enum from enum import Enum
from threading import RLock from threading import RLock
from typing import Literal, List, Sequence from typing import Literal, List, Sequence, Final
import psutil import psutil
import torch import torch
@ -128,6 +128,9 @@ def get_torch_device():
return torch.device("xpu", torch.xpu.current_device()) return torch.device("xpu", torch.xpu.current_device())
else: else:
try: try:
# https://github.com/sayakpaul/diffusers-torchao/blob/bade7a6abb1cab9ef44782e6bcfab76d0237ae1f/inference/benchmark_image.py#L3
# This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
torch.set_float32_matmul_precision("high")
return torch.device(torch.cuda.current_device()) return torch.device(torch.cuda.current_device())
except: except:
warnings.warn("torch.cuda.current_device() did not return a device, returning a CPU torch device") warnings.warn("torch.cuda.current_device() did not return a device, returning a CPU torch device")
@ -319,7 +322,7 @@ try:
except: except:
logging.warning("Could not pick default device.") logging.warning("Could not pick default device.")
current_loaded_models: List["LoadedModel"] = [] current_loaded_models: Final[List["LoadedModel"]] = []
def module_size(module): def module_size(module):
@ -974,6 +977,22 @@ def cast_to_device(tensor, device, dtype, copy=False):
else: else:
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking) return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
FLASH_ATTENTION_ENABLED = False
if not args.disable_flash_attn:
try:
import flash_attn
FLASH_ATTENTION_ENABLED = True
except ImportError:
pass
SAGE_ATTENTION_ENABLED = False
if not args.disable_sage_attention:
try:
import sageattention
SAGE_ATTENTION_ENABLED = True
except ImportError:
pass
def xformers_enabled(): def xformers_enabled():
global directml_device global directml_device
@ -986,6 +1005,30 @@ def xformers_enabled():
return False return False
return XFORMERS_IS_AVAILABLE return XFORMERS_IS_AVAILABLE
def flash_attn_enabled():
global directml_device
global cpu_state
if cpu_state != CPUState.GPU:
return False
if is_intel_xpu():
return False
if directml_device:
return False
return FLASH_ATTENTION_ENABLED
def sage_attention_enabled():
global directml_device
global cpu_state
if cpu_state != CPUState.GPU:
return False
if is_intel_xpu():
return False
if directml_device:
return False
if xformers_enabled():
return False
return SAGE_ATTENTION_ENABLED
def xformers_enabled_vae(): def xformers_enabled_vae():
enabled = xformers_enabled() enabled = xformers_enabled()

View File

@@ -55,17 +55,13 @@ class ModelManageable(Protocol):
     def model_dtype(self) -> torch.dtype:
         return next(self.model.parameters()).dtype

-    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
-        self.patch_model(device_to=device_to, patch_weights=False)
-        return self.model
-
-    def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
+    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
         ...

-    def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
         """
         Unloads the model by moving it to the offload device

-        :param offload_device:
+        :param device_to:
         :param unpatch_weights:
         :return:
         """
@@ -99,6 +95,20 @@ class ModelManageable(Protocol):
     def current_loaded_device(self) -> torch.device:
         return self.current_device

+    def get_model_object(self, name: str) -> torch.nn.Module:
+        from . import utils
+        return utils.get_attr(self.model, name)
+
+    @property
+    def model_options(self) -> dict:
+        if not hasattr(self, "_model_options"):
+            setattr(self, "_model_options", {"transformer_options": {}})
+        return getattr(self, "_model_options")
+
+    @model_options.setter
+    def model_options(self, value):
+        setattr(self, "_model_options", value)
+

 @dataclasses.dataclass
 class MemoryMeasurements:

View File

@ -27,10 +27,11 @@ import torch.nn
from . import model_management, lora from . import model_management, lora
from . import utils from . import utils
from .comfy_types import UnetWrapperFunction
from .float import stochastic_rounding from .float import stochastic_rounding
from .model_base import BaseModel from .model_base import BaseModel
from .model_management_types import ModelManageable, MemoryMeasurements from .model_management_types import ModelManageable, MemoryMeasurements
from .comfy_types import UnetWrapperFunction
def string_to_seed(data): def string_to_seed(data):
crc = 0xFFFFFFFF crc = 0xFFFFFFFF
@ -45,6 +46,7 @@ def string_to_seed(data):
crc >>= 1 crc >>= 1
return crc ^ 0xFFFFFFFF return crc ^ 0xFFFFFFFF
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy() to = model_options["transformer_options"].copy()
@ -106,7 +108,7 @@ class ModelPatcher(ModelManageable):
self.backup = {} self.backup = {}
self.object_patches = {} self.object_patches = {}
self.object_patches_backup = {} self.object_patches_backup = {}
self.model_options = {"transformer_options": {}} self._model_options = {"transformer_options": {}}
self.model_size() self.model_size()
self.load_device = load_device self.load_device = load_device
self.offload_device = offload_device self.offload_device = offload_device
@ -115,6 +117,14 @@ class ModelPatcher(ModelManageable):
self.ckpt_name = ckpt_name self.ckpt_name = ckpt_name
self._memory_measurements = MemoryMeasurements(self.model) self._memory_measurements = MemoryMeasurements(self.model)
@property
def model_options(self) -> dict:
return self._model_options
@model_options.setter
def model_options(self, value):
self._model_options = value
@property @property
def model_device(self) -> torch.device: def model_device(self) -> torch.device:
return self._memory_measurements.device return self._memory_measurements.device
@ -145,7 +155,7 @@ class ModelPatcher(ModelManageable):
n.patches_uuid = self.patches_uuid n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy() n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options) n._model_options = copy.deepcopy(self.model_options)
n.backup = self.backup n.backup = self.backup
n.object_patches_backup = self.object_patches_backup n.object_patches_backup = self.object_patches_backup
return n return n
@ -260,6 +270,11 @@ class ModelPatcher(ModelManageable):
self.model_options["model_function_wrapper"] = wrap_func.to(device) self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_dtype(self): def model_dtype(self):
# this pokes into the internals of diffusion model a little bit
# todo: the base model isn't going to be aware that its diffusion model is patched this way
if isinstance(self.model, BaseModel):
diffusion_model = self.get_model_object("diffusion_model")
return diffusion_model.dtype
if hasattr(self.model, "get_dtype"): if hasattr(self.model, "get_dtype"):
return self.model.get_dtype() return self.model.get_dtype()
@ -293,7 +308,7 @@ class ModelPatcher(ModelManageable):
if filter_prefix is not None: if filter_prefix is not None:
if not k.startswith(filter_prefix): if not k.startswith(filter_prefix):
continue continue
bk = self.backup.get(k, None) bk: torch.nn.Module | None = self.backup.get(k, None)
if bk is not None: if bk is not None:
weight = bk.weight weight = bk.weight
else: else:
@ -494,7 +509,7 @@ class ModelPatcher(ModelManageable):
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]: for key in [weight_key, bias_key]:
bk = self.backup.get(key, None) bk: torch.nn.Module | None = self.backup.get(key, None)
if bk is not None: if bk is not None:
if bk.inplace_update: if bk.inplace_update:
utils.copy_to_param(self.model, key, bk.weight) utils.copy_to_param(self.model, key, bk.weight)

View File

@@ -538,14 +538,16 @@ class DiffusersLoader:
             paths += get_huggingface_repo_list()
         paths = list(frozenset(paths))
-        return {"required": {"model_path": (paths,), }}
+        return {"required": {"model_path": (paths,),
+                             "weight_dtype": (FLUX_WEIGHT_DTYPES,)
+                             }}

     RETURN_TYPES = ("MODEL", "CLIP", "VAE")
     FUNCTION = "load_checkpoint"

     CATEGORY = "advanced/loaders"

-    def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
+    def load_checkpoint(self, model_path, output_vae=True, output_clip=True,weight_dtype:str="default"):
         for search_path in folder_paths.get_folder_paths("diffusers"):
             if os.path.exists(search_path):
                 path = os.path.join(search_path, model_path)
@@ -556,7 +558,8 @@ class DiffusersLoader:
             with comfy_tqdm():
                 model_path = snapshot_download(model_path)

-        return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
+        model_options = get_model_options_for_dtype(weight_dtype)
+        return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)


 class unCLIPCheckpointLoader:
@@ -875,6 +878,14 @@ class ControlNetApplyAdvanced:
             out.append(c)
         return (out[0], out[1])

+
+def get_model_options_for_dtype(weight_dtype):
+    model_options = {}
+    if weight_dtype == "fp8_e4m3fn":
+        model_options["dtype"] = torch.float8_e4m3fn
+    elif weight_dtype == "fp8_e5m2":
+        model_options["dtype"] = torch.float8_e5m2
+    return model_options
+

 class UNETLoader:
     @classmethod
@@ -888,16 +899,14 @@ class UNETLoader:
     CATEGORY = "advanced/loaders"

     def load_unet(self, unet_name, weight_dtype):
-        model_options = {}
-        if weight_dtype == "fp8_e4m3fn":
-            model_options["dtype"] = torch.float8_e4m3fn
-        elif weight_dtype == "fp8_e5m2":
-            model_options["dtype"] = torch.float8_e5m2
+        model_options = get_model_options_for_dtype(weight_dtype)
+
         unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
         model = sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)


 class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):

View File

@ -19,15 +19,19 @@ from __future__ import annotations
import contextlib import contextlib
import itertools import itertools
import json
import logging import logging
import math import math
import os
import random import random
import struct import struct
import sys import sys
import warnings import warnings
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Any from typing import Optional, Any
import accelerate
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
import torch import torch
@@ -55,13 +59,27 @@ def _get_progress_bar_enabled():

 setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))


-def load_torch_file(ckpt, safe_load=False, device=None):
+def load_torch_file(ckpt: str, safe_load=False, device=None):
     if device is None:
         device = torch.device("cpu")
     if ckpt is None:
         raise FileNotFoundError("the checkpoint was not found")
     if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
         sd = safetensors.torch.load_file(ckpt, device=device.type)
+    elif ckpt.lower().endswith("index.json"):
+        # from accelerate
+        index_filename = ckpt
+        checkpoint_folder = os.path.split(index_filename)[0]
+        with open(index_filename) as f:
+            index = json.loads(f.read())
+        if "weight_map" in index:
+            index = index["weight_map"]
+        checkpoint_files = sorted(list(set(index.values())))
+        checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
+        sd: dict[str, torch.Tensor] = {}
+        for checkpoint_file in checkpoint_files:
+            sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
     else:
         if safe_load:
             if not 'weights_only' in torch.load.__code__.co_varnames:
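
For context, the index.json branch added above consumes the standard sharded-safetensors index layout; an illustrative, made-up example of such a file and the shard list the loader derives from it:

# Illustrative contents of a diffusion_pytorch_model.safetensors.index.json (names and sizes are invented).
index = {
    "metadata": {"total_size": 123456789},
    "weight_map": {
        "double_blocks.0.img_attn.qkv.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "final_layer.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
    },
}

# The loader keeps only the unique shard filenames and merges every shard into one state dict.
shards = sorted(set(index["weight_map"].values()))
print(shards)  # ['diffusion_pytorch_model-00001-of-00002.safetensors', 'diffusion_pytorch_model-00002-of-00002.safetensors']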

View File

@ -1,8 +1,11 @@
import logging import logging
import torch import torch
from torch.nn import LayerNorm
from comfy import model_management
from comfy.model_patcher import ModelPatcher from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes
DIFFUSION_MODEL = "diffusion_model" DIFFUSION_MODEL = "diffusion_model"
@ -47,6 +50,65 @@ class TorchCompileModel:
return model, return model,
class QuantizeModel(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL", {}),
"strategy": (["torchao", "quanto"], {"default": "torchao"})
}
}
FUNCTION = "execute"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
RETURN_TYPES = ("MODEL",)
def execute(self, model: ModelPatcher, strategy: str = "torchao"):
logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")
logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
model = model.clone()
unet = model.get_model_object("diffusion_model")
# todo: quantize quantizes in place, which is not desired
# default exclusions
_unused_exclusions = {
"time_embedding.",
"add_embedding.",
"time_in.",
"txt_in.",
"vector_in.",
"img_in.",
"guidance_in.",
"final_layer.",
}
if strategy == "quanto":
from optimum.quanto import quantize, qint8
exclusion_list = [
name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None
]
quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
_in_place_fixme = unet
elif strategy == "torchao":
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
model = model.clone()
unet = model.get_model_object("diffusion_model")
# todo: quantize quantizes in place, which is not desired
# def filter_fn(module: torch.nn.Module, name: str):
# return any("weight" in name for name, _ in (module.named_parameters())) and all(exclusion not in name for exclusion in exclusions)
quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
_in_place_fixme = unet
else:
raise ValueError(f"unknown strategy {strategy}")
model.add_object_patch("diffusion_model", _in_place_fixme)
return model,
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel, "TorchCompileModel": TorchCompileModel,
"QuantizeModel": QuantizeModel,
} }
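
A minimal, self-contained illustration of the torchao path used by the node above; quantize_ rewrites the module in place, which is why the node re-attaches the same object with add_object_patch (requires torchao, whose API may vary by version; the toy module is made up):

import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

toy = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))

# quantize_ mutates `toy` in place rather than returning a new module; afterwards the
# Linear weights are backed by torchao's quantized tensor subclass.
quantize_(toy, int8_dynamic_activation_int8_weight())

print(type(toy[0].weight))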

View File

@@ -88,18 +88,14 @@ class UpscaleModelManageable(ModelManageable):
     def model_dtype(self) -> torch.dtype:
         return next(self.model.parameters()).dtype

-    def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
-        self.model.to(device=device_to)
-        return self.model
-
-    def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
+    def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
         self.model.to(device=device_to)
         return self.model

-    def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
-        self.model.to(device=offload_device)
-        return self.model
+    def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
+        self.model.to(device=device_to)
+        return self.model

     def __str__(self):
         if self.ckpt_name is not None:
             return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"