Memory management and compilation improvements

- Experimental support for Sage Attention on Linux
- The Diffusers loader now supports sharded checkpoints referenced by model index files (diffusion_pytorch_model.safetensors.index.json), including transformer/ subfolders
- Transformers model management now aligns with upstream ComfyUI changes
- Flux attention layers now split fused QKV tensors with torch.unbind
- float8 weight dtypes (fp8_e4m3fn, fp8_e5m2) are supported for model loading in more places
- Experimental quantization strategies using Optimum Quanto and torchao
- Upscale models now interact better with memory management

This update also disables ROCm testing in CI because it is not reliable enough
on consumer hardware; the RX 7600 does not have solid ROCm support.
doctorpangloss 2024-10-09 09:11:22 -07:00
parent 0a25b67ff8
commit bbe2ed330c
16 changed files with 319 additions and 142 deletions

View File

@ -52,8 +52,6 @@ jobs:
runner:
- labels: [self-hosted, Linux, X64, cpu]
container: "ubuntu:22.04"
- labels: [self-hosted, Linux, X64, rocm-7600-8gb]
container: "rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0"
- labels: [self-hosted, Linux, X64, cuda-3060-12gb]
container: "nvcr.io/nvidia/pytorch:24.03-py3"
steps:

View File

@ -99,6 +99,8 @@ def _create_parser() -> EnhancedConfigArgParser:
help="Use the new pytorch 2.0 cross attention function.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
parser.add_argument("--disable-flash-attn", action="store_true", help="Disable Flash Attention")
parser.add_argument("--disable-sage-attention", action="store_true", help="Disable Sage Attention")
upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")

View File

@ -75,6 +75,8 @@ class Configuration(dict):
use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
disable_xformers (bool): Disable xformers.
disable_flash_attn (bool): Disable flash_attn package attention.
disable_sage_attention (bool): Disable sage attention package attention.
gpu_only (bool): Run everything on the GPU.
highvram (bool): Keep models in GPU memory.
normalvram (bool): Default VRAM usage setting.
@ -157,6 +159,8 @@ class Configuration(dict):
self.use_quad_cross_attention: bool = False
self.use_pytorch_cross_attention: bool = False
self.disable_xformers: bool = False
self.disable_flash_attn: bool = False
self.disable_sage_attention: bool = False
self.gpu_only: bool = False
self.highvram: bool = False
self.normalvram: bool = False

View File

@ -13,9 +13,17 @@ def first_file(path, filenames) -> str | None:
return None
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None, model_options=None):
if model_options is None:
model_options = {}
diffusion_model_names = [
"diffusion_pytorch_model.fp16.safetensors",
"diffusion_pytorch_model.safetensors",
"diffusion_pytorch_model.fp16.bin",
"diffusion_pytorch_model.bin",
"diffusion_pytorch_model.safetensors.index.json"
]
unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names) or first_file(os.path.join(model_path, "transformer"), diffusion_model_names)
vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]
@ -28,7 +36,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
unet = None
if unet_path is not None:
unet = sd.load_diffusion_model(unet_path)
unet = sd.load_diffusion_model(unet_path, model_options=model_options)
clip = None
textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"])

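As a rough usage sketch of the change above: a Diffusers-format directory whose transformer weights live behind a diffusion_pytorch_model.safetensors.index.json can now be loaded directly, optionally with an fp8 weight dtype. The directory and import paths below are assumptions; the load_diffusers signature and the model_options dtype key come from this diff.

# Hedged sketch; the module path and model directory are assumptions, not part of this commit.
import torch
from comfy import diffusers_load

unet, clip, vae = diffusers_load.load_diffusers(
    "/models/diffusers/some-flux-checkpoint",       # hypothetical local snapshot with a transformer/ subfolder
    output_vae=True,
    output_clip=True,
    embedding_directory=None,
    model_options={"dtype": torch.float8_e4m3fn},   # optional fp8 weight dtype, as in get_model_options_for_dtype
)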
View File

@ -79,13 +79,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
# if we have flash-attn installed, try to use it
try:
import flash_attn
attn_override_kwargs = {
"attn_implementation": "flash_attention_2",
**kwargs_to_try[0]
}
kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
if model_management.flash_attn_enabled():
attn_override_kwargs = {
"attn_implementation": "flash_attention_2",
**kwargs_to_try[0]
}
kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
except ImportError:
pass
for i, props in enumerate(kwargs_to_try):
@ -303,16 +303,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
def model_dtype(self) -> torch.dtype:
return self.model.dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights=False) -> torch.nn.Module:
warnings.warn("Transformers models do not currently support adapters like LoRAs")
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
return self.model.to(device=device_to)
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
return self.model.to(device=device_to)
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
return self.model.to(device=offload_device)
def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
model = copy.copy(self)
model._processor = processor

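For context, the attn_implementation key that gets prepended to kwargs_to_try is the standard transformers loading argument; a hedged, standalone sketch (the model id is a placeholder and flash-attn must be installed):

# Hedged sketch; "some-org/some-llm" is a placeholder model id.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-llm",
    attn_implementation="flash_attention_2",  # what the override above requests when flash_attn is usable
    torch_dtype=torch.bfloat16,               # flash_attention_2 requires fp16/bf16 weights
)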
View File

@ -149,14 +149,16 @@ class DoubleStreamBlock(nn.Module):
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_qkv = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k, img_v = torch.unbind(img_qkv, dim=0)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_qkv = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k, txt_v = torch.unbind(txt_qkv, dim=0)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
@ -221,7 +223,8 @@ class SingleStreamBlock(nn.Module):
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
qkv = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = torch.unbind(qkv, dim=0)
q, k = self.norm(q, k, v)
# compute attention

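For reference, torch.unbind along the stacked dimension yields the same q, k, v views that iterating over the permuted tensor produced; a standalone check with arbitrary shapes:

# Standalone illustration with arbitrary shapes; unbind(dim=0) returns views, not copies.
import torch

b, s, heads, dim_head = 2, 16, 4, 32
qkv = torch.randn(b, s, 3 * heads * dim_head)
qkv = qkv.view(b, s, 3, heads, -1).permute(2, 0, 3, 1, 4)   # (3, b, heads, s, dim_head)
q, k, v = torch.unbind(qkv, dim=0)
assert torch.equal(q, qkv[0]) and torch.equal(k, qkv[1]) and torch.equal(v, qkv[2])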
View File

@ -1,10 +1,12 @@
import logging
import math
from functools import wraps
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional
import logging
from torch import nn, einsum
from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
@ -12,14 +14,22 @@ from ... import model_management
if model_management.xformers_enabled():
import xformers # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
if model_management.sage_attention_enabled():
from sageattention import sageattn
if model_management.flash_attn_enabled():
from flash_attn import flash_attn_func
from ...cli_args import args
from ... import ops
ops = ops.disable_weight_init
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
def get_attn_precision(attn_precision):
if args.dont_upcast_attention:
return None
@ -27,12 +37,13 @@ def get_attn_precision(attn_precision):
return FORCE_UPCAST_ATTENTION_DTYPE
return attn_precision
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
return {el: True for el in arr}.keys()
def default(val, d):
@ -82,9 +93,11 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)
@ -98,7 +111,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
h = heads
if skip_reshape:
q, k, v = map(
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
@ -122,7 +135,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
if exists(mask):
if mask.dtype == torch.bool:
mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
mask = rearrange(mask, 'b ... -> b (...)') # TODO: check if this bool part matches pytorch attention
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
@ -167,13 +180,12 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype
upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
bytes_per_token = torch.finfo(torch.float32).bits // 8
else:
bytes_per_token = torch.finfo(query.dtype).bits//8
bytes_per_token = torch.finfo(query.dtype).bits // 8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
@ -215,9 +227,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.to(dtype)
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)
@ -231,7 +244,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
h = heads
if skip_reshape:
q, k, v = map(
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
@ -262,16 +275,15 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
if mask is not None:
if len(mask.shape) == 2:
@ -289,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
if upcast:
with torch.autocast(enabled=False, device_type = 'cuda'):
with torch.autocast(enabled=False, device_type='cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
@ -331,11 +343,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return r1
BROKEN_XFORMERS = False
if model_management.xformers_enabled():
x_vers = xformers.__version__
# XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
@ -346,10 +353,6 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = False
if BROKEN_XFORMERS:
if b * heads > 65535:
disabled_xformers = True
if not disabled_xformers:
if torch.jit.is_tracing() or torch.jit.is_scripting():
disabled_xformers = True
@ -358,7 +361,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape:
q, k, v = map(
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
@ -390,22 +393,36 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return out
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
def pytorch_style_decl(func):
@wraps(func)
def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
out = func(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
return out
return wrapper
@pytorch_style_decl
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
@pytorch_style_decl
def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
@pytorch_style_decl
def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
return flash_attn_func(q, k, v)
optimized_attention = attention_basic
@ -426,10 +443,11 @@ else:
optimized_attention_masked = optimized_attention
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
return attention_pytorch # TODO: need to confirm but this is probably slightly faster for small inputs in all cases
else:
return attention_basic
@ -493,7 +511,7 @@ class BasicTransformerBlock(nn.Module):
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention:
@ -507,7 +525,7 @@ class BasicTransformerBlock(nn.Module):
context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
@ -641,6 +659,7 @@ class SpatialTransformer(nn.Module):
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
@ -653,23 +672,23 @@ class SpatialTransformer(nn.Module):
self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
if not use_linear:
self.proj_in = operations.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0, dtype=dtype, device=device)
inner_dim,
kernel_size=1,
stride=1,
padding=0, dtype=dtype, device=device)
else:
self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
for d in range(depth)]
for d in range(depth)]
)
if not use_linear:
self.proj_out = operations.Conv2d(inner_dim,in_channels,
kernel_size=1,
stride=1,
padding=0, dtype=dtype, device=device)
self.proj_out = operations.Conv2d(inner_dim, in_channels,
kernel_size=1,
stride=1,
padding=0, dtype=dtype, device=device)
else:
self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.use_linear = use_linear
@ -699,27 +718,27 @@ class SpatialTransformer(nn.Module):
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
attn_precision=None,
dtype=None, device=None, operations=ops
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
attn_precision=None,
dtype=None, device=None, operations=ops
):
super().__init__(
in_channels,
@ -785,13 +804,13 @@ class SpatialVideoTransformer(SpatialTransformer):
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={}
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x
@ -801,7 +820,7 @@ class SpatialVideoTransformer(SpatialTransformer):
if self.use_spatial_context:
assert (
context.ndim == 3
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
if time_context is None:
@ -830,7 +849,7 @@ class SpatialVideoTransformer(SpatialTransformer):
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
zip(self.transformer_blocks, self.time_stack)
):
transformer_options["block_index"] = it_
x = block(
@ -844,7 +863,7 @@ class SpatialVideoTransformer(SpatialTransformer):
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = mix_block(x_mix, context=time_context) # TODO: transformer_options
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)
@ -858,5 +877,3 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x)
out = x + x_in
return out

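The pytorch_style_decl wrapper above factors out the (batch, sequence, heads * dim_head) <-> (batch, heads, sequence, dim_head) reshaping shared by the PyTorch, Sage, and flash-attn kernels; a stripped-down standalone version of that contract (hypothetical names, assumes torch >= 2.0):

# Hedged standalone sketch of the reshape contract; names here are illustrative only.
import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v, heads, mask=None):
    b, _, dim = q.shape
    dim_head = dim // heads
    # (b, s, heads * dim_head) -> (b, heads, s, dim_head)
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    # back to (b, s, heads * dim_head)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(2, 77, 8 * 64)
assert sdpa_attention(q, k, v, heads=8).shape == (2, 77, 8 * 64)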
View File

@ -108,7 +108,6 @@ class BaseModel(torch.nn.Module):
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
if model_management.force_channels_last():
# todo: ???
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))

View File

@ -535,10 +535,6 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
for user_dir in Path(local_dir_root).iterdir():
for model_dir in user_dir.iterdir():
try:
_hf_fs.resolve_path(str(user_dir / model_dir))
except Exception as exc_info:
logging.debug(f"HuggingFaceFS did not think this was a valid repo: {user_dir.name}/{model_dir.name} with error {exc_info}", exc_info)
existing_local_dir_repos.add(f"{user_dir.name}/{model_dir.name}")
known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)

View File

@ -23,7 +23,7 @@ import sys
import warnings
from enum import Enum
from threading import RLock
from typing import Literal, List, Sequence
from typing import Literal, List, Sequence, Final
import psutil
import torch
@ -128,6 +128,9 @@ def get_torch_device():
return torch.device("xpu", torch.xpu.current_device())
else:
try:
# https://github.com/sayakpaul/diffusers-torchao/blob/bade7a6abb1cab9ef44782e6bcfab76d0237ae1f/inference/benchmark_image.py#L3
# This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
torch.set_float32_matmul_precision("high")
return torch.device(torch.cuda.current_device())
except:
warnings.warn("torch.cuda.current_device() did not return a device, returning a CPU torch device")
@ -319,7 +322,7 @@ try:
except:
logging.warning("Could not pick default device.")
current_loaded_models: List["LoadedModel"] = []
current_loaded_models: Final[List["LoadedModel"]] = []
def module_size(module):
@ -974,6 +977,22 @@ def cast_to_device(tensor, device, dtype, copy=False):
else:
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
FLASH_ATTENTION_ENABLED = False
if not args.disable_flash_attn:
try:
import flash_attn
FLASH_ATTENTION_ENABLED = True
except ImportError:
pass
SAGE_ATTENTION_ENABLED = False
if not args.disable_sage_attention:
try:
import sageattention
SAGE_ATTENTION_ENABLED = True
except ImportError:
pass
def xformers_enabled():
global directml_device
@ -986,6 +1005,30 @@ def xformers_enabled():
return False
return XFORMERS_IS_AVAILABLE
def flash_attn_enabled():
global directml_device
global cpu_state
if cpu_state != CPUState.GPU:
return False
if is_intel_xpu():
return False
if directml_device:
return False
return FLASH_ATTENTION_ENABLED
def sage_attention_enabled():
global directml_device
global cpu_state
if cpu_state != CPUState.GPU:
return False
if is_intel_xpu():
return False
if directml_device:
return False
if xformers_enabled():
return False
return SAGE_ATTENTION_ENABLED
def xformers_enabled_vae():
enabled = xformers_enabled()

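A small sketch of how the new helpers can be probed (import path assumed; note sage_attention_enabled deliberately yields to xformers when both are present):

# Hedged sketch; assumes the package layout used elsewhere in this commit.
from comfy import model_management as mm

print("xformers:", mm.xformers_enabled())
print("flash_attn:", mm.flash_attn_enabled())          # False on CPU, Intel XPU, or DirectML
print("sage attention:", mm.sage_attention_enabled())  # also False whenever xformers is enabled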
View File

@ -55,17 +55,13 @@ class ModelManageable(Protocol):
def model_dtype(self) -> torch.dtype:
return next(self.model.parameters()).dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
self.patch_model(device_to=device_to, patch_weights=False)
return self.model
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
...
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
"""
Unloads the model by moving it to the offload device
:param offload_device:
:param device_to:
:param unpatch_weights:
:return:
"""
@ -99,6 +95,20 @@ class ModelManageable(Protocol):
def current_loaded_device(self) -> torch.device:
return self.current_device
def get_model_object(self, name: str) -> torch.nn.Module:
from . import utils
return utils.get_attr(self.model, name)
@property
def model_options(self) -> dict:
if not hasattr(self, "_model_options"):
setattr(self, "_model_options", {"transformer_options": {}})
return getattr(self, "_model_options")
@model_options.setter
def model_options(self, value):
setattr(self, "_model_options", value)
@dataclasses.dataclass
class MemoryMeasurements:

View File

@ -27,10 +27,11 @@ import torch.nn
from . import model_management, lora
from . import utils
from .comfy_types import UnetWrapperFunction
from .float import stochastic_rounding
from .model_base import BaseModel
from .model_management_types import ModelManageable, MemoryMeasurements
from .comfy_types import UnetWrapperFunction
def string_to_seed(data):
crc = 0xFFFFFFFF
@ -45,6 +46,7 @@ def string_to_seed(data):
crc >>= 1
return crc ^ 0xFFFFFFFF
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@ -106,7 +108,7 @@ class ModelPatcher(ModelManageable):
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options": {}}
self._model_options = {"transformer_options": {}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
@ -115,6 +117,14 @@ class ModelPatcher(ModelManageable):
self.ckpt_name = ckpt_name
self._memory_measurements = MemoryMeasurements(self.model)
@property
def model_options(self) -> dict:
return self._model_options
@model_options.setter
def model_options(self, value):
self._model_options = value
@property
def model_device(self) -> torch.device:
return self._memory_measurements.device
@ -145,7 +155,7 @@ class ModelPatcher(ModelManageable):
n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n._model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
return n
@ -260,6 +270,11 @@ class ModelPatcher(ModelManageable):
self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_dtype(self):
# this pokes into the internals of diffusion model a little bit
# todo: the base model isn't going to be aware that its diffusion model is patched this way
if isinstance(self.model, BaseModel):
diffusion_model = self.get_model_object("diffusion_model")
return diffusion_model.dtype
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
@ -293,7 +308,7 @@ class ModelPatcher(ModelManageable):
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
bk = self.backup.get(k, None)
bk: torch.nn.Module | None = self.backup.get(k, None)
if bk is not None:
weight = bk.weight
else:
@ -494,7 +509,7 @@ class ModelPatcher(ModelManageable):
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
bk: torch.nn.Module | None = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
utils.copy_to_param(self.model, key, bk.weight)

View File

@ -538,14 +538,16 @@ class DiffusersLoader:
paths += get_huggingface_repo_list()
paths = list(frozenset(paths))
return {"required": {"model_path": (paths,), }}
return {"required": {"model_path": (paths,),
"weight_dtype": (FLUX_WEIGHT_DTYPES,)
}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders"
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
def load_checkpoint(self, model_path, output_vae=True, output_clip=True,weight_dtype:str="default"):
for search_path in folder_paths.get_folder_paths("diffusers"):
if os.path.exists(search_path):
path = os.path.join(search_path, model_path)
@ -556,7 +558,8 @@ class DiffusersLoader:
with comfy_tqdm():
model_path = snapshot_download(model_path)
return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
model_options = get_model_options_for_dtype(weight_dtype)
return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)
class unCLIPCheckpointLoader:
@ -875,6 +878,14 @@ class ControlNetApplyAdvanced:
out.append(c)
return (out[0], out[1])
def get_model_options_for_dtype(weight_dtype):
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
return model_options
class UNETLoader:
@classmethod
@ -888,16 +899,14 @@ class UNETLoader:
CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype):
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
model_options = get_model_options_for_dtype(weight_dtype)
unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
model = sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
class CLIPLoader:
@classmethod
def INPUT_TYPES(s):

View File

@ -19,15 +19,19 @@ from __future__ import annotations
import contextlib
import itertools
import json
import logging
import math
import os
import random
import struct
import sys
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Any
import accelerate
import numpy as np
import safetensors.torch
import torch
@ -55,13 +59,27 @@ def _get_progress_bar_enabled():
setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
def load_torch_file(ckpt, safe_load=False, device=None):
def load_torch_file(ckpt: str, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt is None:
raise FileNotFoundError("the checkpoint was not found")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
elif ckpt.lower().endswith("index.json"):
# from accelerate
index_filename = ckpt
checkpoint_folder = os.path.split(index_filename)[0]
with open(index_filename) as f:
index = json.loads(f.read())
if "weight_map" in index:
index = index["weight_map"]
checkpoint_files = sorted(list(set(index.values())))
checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
sd: dict[str, torch.Tensor] = {}
for checkpoint_file in checkpoint_files:
sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:

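The new index.json branch follows the standard Hugging Face sharded-safetensors layout: the index's weight_map maps each tensor name to the shard file that contains it, and every referenced shard is loaded and merged into one state dict. An abbreviated, hypothetical example of such an index:

# Hypothetical, abbreviated diffusion_pytorch_model.safetensors.index.json content
# (real indexes also carry a "metadata" section with the total size).
index = {
    "weight_map": {
        "transformer_blocks.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "transformer_blocks.0.attn.to_k.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
        "final_layer.linear.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
    }
}
# load_torch_file resolves each shard relative to the index file and update()s them into a single dict.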
View File

@ -1,8 +1,11 @@
import logging
import torch
from torch.nn import LayerNorm
from comfy import model_management
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes
DIFFUSION_MODEL = "diffusion_model"
@ -47,6 +50,65 @@ class TorchCompileModel:
return model,
class QuantizeModel(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL", {}),
"strategy": (["torchao", "quanto"], {"default": "torchao"})
}
}
FUNCTION = "execute"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
RETURN_TYPES = ("MODEL",)
def execute(self, model: ModelPatcher, strategy: str = "torchao"):
logging.warning(f"Quantizing {model} this way quantizes it in place, making it unsuitable for cloning. All uses of this model will be quantized.")
logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
model = model.clone()
unet = model.get_model_object("diffusion_model")
# todo: quantize quantizes in place, which is not desired
# default exclusions
_unused_exclusions = {
"time_embedding.",
"add_embedding.",
"time_in.",
"txt_in.",
"vector_in.",
"img_in.",
"guidance_in.",
"final_layer.",
}
if strategy == "quanto":
from optimum.quanto import quantize, qint8
exclusion_list = [
name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None
]
quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
_in_place_fixme = unet
elif strategy == "torchao":
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
model = model.clone()
unet = model.get_model_object("diffusion_model")
# todo: quantize quantizes in place, which is not desired
# def filter_fn(module: torch.nn.Module, name: str):
# return any("weight" in name for name, _ in (module.named_parameters())) and all(exclusion not in name for exclusion in exclusions)
quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
_in_place_fixme = unet
else:
raise ValueError(f"unknown strategy {strategy}")
model.add_object_patch("diffusion_model", _in_place_fixme)
return model,
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
"QuantizeModel": QuantizeModel,
}

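Outside the node wrapper, the Quanto path boils down to the same in-place call on the module; a minimal hedged sketch on a toy module (optimum-quanto must be installed, and as the warnings above note, results are experimental):

# Hedged sketch of weight-only int8 quantization with Optimum Quanto on a toy module.
import torch
from optimum.quanto import quantize, freeze, qint8

net = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64))
quantize(net, weights=qint8)   # mutates the module in place, as QuantizeModel warns
freeze(net)                    # materialize the quantized weights
out = net(torch.randn(1, 64))  # forward still works on the quantized module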
View File

@ -88,18 +88,14 @@ class UpscaleModelManageable(ModelManageable):
def model_dtype(self) -> torch.dtype:
return next(self.model.parameters()).dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
self.model.to(device=device_to)
return self.model
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
self.model.to(device=device_to)
return self.model
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
self.model.to(device=offload_device)
return self.model
def __str__(self):
if self.ckpt_name is not None:
return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"