Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-10 06:10:50 +08:00
Memory management and compilation improvements
- Experimental support for sage attention on Linux
- Diffusers loader now supports model indices
- Transformers model management now aligns with updates to ComfyUI
- Flux layers correctly use unbind (see the sketch below)
- Add float8 support for model loading in more places
- Experimental quantization approaches from Quanto and torchao
- Model upscaling interacts with memory management better

This update also disables ROCm testing because it isn't reliable enough on consumer hardware. ROCm is not really supported by the 7600.
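The unbind item refers to how the Flux layers now split the fused qkv projection: instead of tuple-unpacking a permuted view, they call torch.unbind along the stacked qkv dimension. A minimal sketch of the pattern with made-up sizes (not the exact Flux module):

```python
import torch

batch, seq, heads, head_dim = 2, 16, 8, 64
qkv = torch.randn(batch, seq, 3 * heads * head_dim)  # output of a fused qkv projection

# stack as (3, batch, heads, seq, head_dim), then split along the leading dim
qkv = qkv.view(batch, seq, 3, heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = torch.unbind(qkv, dim=0)  # each is (batch, heads, seq, head_dim)
```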
This commit is contained in:
parent 0a25b67ff8
commit bbe2ed330c
.github/workflows/test.yml (vendored): 2 changed lines
@@ -52,8 +52,6 @@ jobs:
runner:
- labels: [self-hosted, Linux, X64, cpu]
container: "ubuntu:22.04"
- labels: [self-hosted, Linux, X64, rocm-7600-8gb]
container: "rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0"
- labels: [self-hosted, Linux, X64, cuda-3060-12gb]
container: "nvcr.io/nvidia/pytorch:24.03-py3"
steps:
@@ -99,6 +99,8 @@ def _create_parser() -> EnhancedConfigArgParser:
help="Use the new pytorch 2.0 cross attention function.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
parser.add_argument("--disable-flash-attn", action="store_true", help="Disable Flash Attention")
parser.add_argument("--disable-sage-attention", action="store_true", help="Disable Sage Attention")
upcast = parser.add_mutually_exclusive_group()
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")

@@ -75,6 +75,8 @@ class Configuration(dict):
use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
disable_xformers (bool): Disable xformers.
disable_flash_attn (bool): Disable flash_attn package attention.
disable_sage_attention (bool): Disable sage attention package attention.
gpu_only (bool): Run everything on the GPU.
highvram (bool): Keep models in GPU memory.
normalvram (bool): Default VRAM usage setting.

@@ -157,6 +159,8 @@ class Configuration(dict):
self.use_quad_cross_attention: bool = False
self.use_pytorch_cross_attention: bool = False
self.disable_xformers: bool = False
self.disable_flash_attn: bool = False
self.disable_sage_attention: bool = False
self.gpu_only: bool = False
self.highvram: bool = False
self.normalvram: bool = False
@@ -13,9 +13,17 @@ def first_file(path, filenames) -> str | None:
return None
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names)
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None, model_options=None):
if model_options is None:
model_options = {}
diffusion_model_names = [
"diffusion_pytorch_model.fp16.safetensors",
"diffusion_pytorch_model.safetensors",
"diffusion_pytorch_model.fp16.bin",
"diffusion_pytorch_model.bin",
"diffusion_pytorch_model.safetensors.index.json"
]
unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names) or first_file(os.path.join(model_path, "transformer"), diffusion_model_names)
vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names)
text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]

@@ -28,7 +36,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
unet = None
if unet_path is not None:
unet = sd.load_diffusion_model(unet_path)
unet = sd.load_diffusion_model(unet_path, model_options=model_options)
clip = None
textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"])
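Taken together with the get_model_options_for_dtype helper added later in this diff, the new model_options parameter lets a caller request float8 weights when loading a Diffusers-layout checkpoint. A hypothetical call (the module path, local path, and (model, clip, vae) return shape are assumptions, not taken from this diff):

```python
import torch
from comfy import diffusers_load  # assumed module path for the loader shown above

# same "dtype" key that get_model_options_for_dtype produces for "fp8_e4m3fn"
model_options = {"dtype": torch.float8_e4m3fn}

unet, clip, vae = diffusers_load.load_diffusers(
    "/path/to/diffusers/checkpoint",  # hypothetical local snapshot directory
    output_vae=True,
    output_clip=True,
    model_options=model_options,
)
```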
@@ -79,13 +79,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
# if we have flash-attn installed, try to use it
try:
import flash_attn
attn_override_kwargs = {
"attn_implementation": "flash_attention_2",
**kwargs_to_try[0]
}
kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
if model_management.flash_attn_enabled():
attn_override_kwargs = {
"attn_implementation": "flash_attention_2",
**kwargs_to_try[0]
}
kwargs_to_try = (attn_override_kwargs, *kwargs_to_try)
logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried")
except ImportError:
pass
for i, props in enumerate(kwargs_to_try):

@@ -303,16 +303,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel):
def model_dtype(self) -> torch.dtype:
return self.model.dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights=False) -> torch.nn.Module:
warnings.warn("Transformers models do not currently support adapters like LoRAs")
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
return self.model.to(device=device_to)
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
return self.model.to(device=device_to)
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
return self.model.to(device=offload_device)
def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel:
model = copy.copy(self)
model._processor = processor
@@ -149,14 +149,16 @@ class DoubleStreamBlock(nn.Module):
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_qkv = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k, img_v = torch.unbind(img_qkv, dim=0)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_qkv = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k, txt_v = torch.unbind(txt_qkv, dim=0)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention

@@ -221,7 +223,8 @@ class SingleStreamBlock(nn.Module):
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
qkv = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = torch.unbind(qkv, dim=0)
q, k = self.norm(q, k, v)
# compute attention
@@ -1,10 +1,12 @@
import logging
import math
from functools import wraps
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional
import logging
from torch import nn, einsum
from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention

@@ -12,14 +14,22 @@ from ... import model_management
if model_management.xformers_enabled():
import xformers # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
import xformers.ops # pylint: disable=import-error
if model_management.sage_attention_enabled():
from sageattention import sageattn
if model_management.flash_attn_enabled():
from flash_attn import flash_attn_func
from ...cli_args import args
from ... import ops
ops = ops.disable_weight_init
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
def get_attn_precision(attn_precision):
if args.dont_upcast_attention:
return None
@@ -27,12 +37,13 @@ def get_attn_precision(attn_precision):
return FORCE_UPCAST_ATTENTION_DTYPE
return attn_precision
def exists(val):
return val is not None
def uniq(arr):
return{el: True for el in arr}.keys()
return {el: True for el in arr}.keys()
def default(val, d):

@@ -82,9 +93,11 @@ class FeedForward(nn.Module):
def forward(self, x):
return self.net(x)
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)

@@ -98,7 +111,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
h = heads
if skip_reshape:
q, k, v = map(
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)

@@ -122,7 +135,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
if exists(mask):
if mask.dtype == torch.bool:
mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention
mask = rearrange(mask, 'b ... -> b (...)') # TODO: check if this bool part matches pytorch attention
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)

@@ -167,13 +180,12 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head)
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
dtype = query.dtype
upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
if upcast_attention:
bytes_per_token = torch.finfo(torch.float32).bits//8
bytes_per_token = torch.finfo(torch.float32).bits // 8
else:
bytes_per_token = torch.finfo(query.dtype).bits//8
bytes_per_token = torch.finfo(query.dtype).bits // 8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key.shape
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
@@ -215,9 +227,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.to(dtype)
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2)
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
attn_precision = get_attn_precision(attn_precision)

@@ -231,7 +244,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
h = heads
if skip_reshape:
q, k, v = map(
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)

@@ -262,16 +275,15 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
if mask is not None:
if len(mask.shape) == 2:

@@ -289,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
if upcast:
with torch.autocast(enabled=False, device_type = 'cuda'):
with torch.autocast(enabled=False, device_type='cuda'):
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
else:
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale

@@ -331,11 +343,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return r1
BROKEN_XFORMERS = False
if model_management.xformers_enabled():
x_vers = xformers.__version__
# XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:

@@ -346,10 +353,6 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = False
if BROKEN_XFORMERS:
if b * heads > 65535:
disabled_xformers = True
if not disabled_xformers:
if torch.jit.is_tracing() or torch.jit.is_scripting():
disabled_xformers = True

@@ -358,7 +361,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
if skip_reshape:
q, k, v = map(
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),
(q, k, v),
)
@@ -390,22 +393,36 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return out
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
def pytorch_style_decl(func):
@wraps(func)
def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
out = func(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
return out
return wrapper
@pytorch_style_decl
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
@pytorch_style_decl
def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
@pytorch_style_decl
def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
return flash_attn_func(q, k, v)
optimized_attention = attention_basic

@@ -426,10 +443,11 @@ else:
optimized_attention_masked = optimized_attention
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases
return attention_pytorch # TODO: need to confirm but this is probably slightly faster for small inputs in all cases
else:
return attention_basic
@@ -493,7 +511,7 @@ class BasicTransformerBlock(nn.Module):
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
if disable_temporal_crossattention:

@@ -507,7 +525,7 @@ class BasicTransformerBlock(nn.Module):
context_dim_attn2 = context_dim
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)

@@ -641,6 +659,7 @@ class SpatialTransformer(nn.Module):
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
@@ -653,23 +672,23 @@ class SpatialTransformer(nn.Module):
self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
if not use_linear:
self.proj_in = operations.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0, dtype=dtype, device=device)
inner_dim,
kernel_size=1,
stride=1,
padding=0, dtype=dtype, device=device)
else:
self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
for d in range(depth)]
for d in range(depth)]
)
if not use_linear:
self.proj_out = operations.Conv2d(inner_dim,in_channels,
kernel_size=1,
stride=1,
padding=0, dtype=dtype, device=device)
self.proj_out = operations.Conv2d(inner_dim, in_channels,
kernel_size=1,
stride=1,
padding=0, dtype=dtype, device=device)
else:
self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device)
self.use_linear = use_linear
@@ -699,27 +718,27 @@
class SpatialVideoTransformer(SpatialTransformer):
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
attn_precision=None,
dtype=None, device=None, operations=ops
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
use_linear=False,
context_dim=None,
use_spatial_context=False,
timesteps=None,
merge_strategy: str = "fixed",
merge_factor: float = 0.5,
time_context_dim=None,
ff_in=False,
checkpoint=False,
time_depth=1,
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
attn_precision=None,
dtype=None, device=None, operations=ops
):
super().__init__(
in_channels,
@@ -785,13 +804,13 @@ class SpatialVideoTransformer(SpatialTransformer):
)
def forward(
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={}
self,
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
time_context: Optional[torch.Tensor] = None,
timesteps: Optional[int] = None,
image_only_indicator: Optional[torch.Tensor] = None,
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
x_in = x

@@ -801,7 +820,7 @@ class SpatialVideoTransformer(SpatialTransformer):
if self.use_spatial_context:
assert (
context.ndim == 3
context.ndim == 3
), f"n dims of spatial context should be 3 but are {context.ndim}"
if time_context is None:

@@ -830,7 +849,7 @@ class SpatialVideoTransformer(SpatialTransformer):
emb = emb[:, None, :]
for it_, (block, mix_block) in enumerate(
zip(self.transformer_blocks, self.time_stack)
zip(self.transformer_blocks, self.time_stack)
):
transformer_options["block_index"] = it_
x = block(

@@ -844,7 +863,7 @@ class SpatialVideoTransformer(SpatialTransformer):
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = mix_block(x_mix, context=time_context) # TODO: transformer_options
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)

@@ -858,5 +877,3 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x)
out = x + x_in
return out
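The pytorch_style_decl decorator introduced above factors the (batch, seq, heads * dim_head) to (batch, heads, seq, dim_head) reshaping out of each attention backend, which is how the new sage and flash-attn wrappers stay one line long. A self-contained sketch of the same pattern, using only the PyTorch SDPA backend (shapes are illustrative, not the file's exact code):

```python
import torch
from functools import wraps


def pytorch_style_decl(func):
    # Reshape (b, seq, heads * dim_head) -> (b, heads, seq, dim_head), call the
    # backend, then flatten the heads back into the channel dimension.
    @wraps(func)
    def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
        if skip_reshape:
            b, _, _, dim_head = q.shape
        else:
            b, _, dim_head = q.shape
            dim_head //= heads
            q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v))
        out = func(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
        return out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    return wrapper


@pytorch_style_decl
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)


q = k = v = torch.randn(2, 77, 8 * 64)     # (batch, seq, heads * dim_head)
out = attention_pytorch(q, k, v, heads=8)  # -> (2, 77, 512)
```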
@@ -108,7 +108,6 @@ class BaseModel(torch.nn.Module):
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
if model_management.force_channels_last():
# todo: ???
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
@@ -535,10 +535,6 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
for user_dir in Path(local_dir_root).iterdir():
for model_dir in user_dir.iterdir():
try:
_hf_fs.resolve_path(str(user_dir / model_dir))
except Exception as exc_info:
logging.debug(f"HuggingFaceFS did not think this was a valid repo: {user_dir.name}/{model_dir.name} with error {exc_info}", exc_info)
existing_local_dir_repos.add(f"{user_dir.name}/{model_dir.name}")
known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)
@@ -23,7 +23,7 @@ import sys
import warnings
from enum import Enum
from threading import RLock
from typing import Literal, List, Sequence
from typing import Literal, List, Sequence, Final
import psutil
import torch

@@ -128,6 +128,9 @@ def get_torch_device():
return torch.device("xpu", torch.xpu.current_device())
else:
try:
# https://github.com/sayakpaul/diffusers-torchao/blob/bade7a6abb1cab9ef44782e6bcfab76d0237ae1f/inference/benchmark_image.py#L3
# This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
torch.set_float32_matmul_precision("high")
return torch.device(torch.cuda.current_device())
except:
warnings.warn("torch.cuda.current_device() did not return a device, returning a CPU torch device")

@@ -319,7 +322,7 @@ try:
except:
logging.warning("Could not pick default device.")
current_loaded_models: List["LoadedModel"] = []
current_loaded_models: Final[List["LoadedModel"]] = []
def module_size(module):
@@ -974,6 +977,22 @@ def cast_to_device(tensor, device, dtype, copy=False):
else:
return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking)
FLASH_ATTENTION_ENABLED = False
if not args.disable_flash_attn:
try:
import flash_attn
FLASH_ATTENTION_ENABLED = True
except ImportError:
pass
SAGE_ATTENTION_ENABLED = False
if not args.disable_sage_attention:
try:
import sageattention
SAGE_ATTENTION_ENABLED = True
except ImportError:
pass
def xformers_enabled():
global directml_device

@@ -986,6 +1005,30 @@ def xformers_enabled():
return False
return XFORMERS_IS_AVAILABLE
def flash_attn_enabled():
global directml_device
global cpu_state
if cpu_state != CPUState.GPU:
return False
if is_intel_xpu():
return False
if directml_device:
return False
return FLASH_ATTENTION_ENABLED
def sage_attention_enabled():
global directml_device
global cpu_state
if cpu_state != CPUState.GPU:
return False
if is_intel_xpu():
return False
if directml_device:
return False
if xformers_enabled():
return False
return SAGE_ATTENTION_ENABLED
def xformers_enabled_vae():
enabled = xformers_enabled()
@@ -55,17 +55,13 @@ class ModelManageable(Protocol):
def model_dtype(self) -> torch.dtype:
return next(self.model.parameters()).dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
self.patch_model(device_to=device_to, patch_weights=False)
return self.model
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
...
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
"""
Unloads the model by moving it to the offload device
:param offload_device:
:param device_to:
:param unpatch_weights:
:return:
"""

@@ -99,6 +95,20 @@ class ModelManageable(Protocol):
def current_loaded_device(self) -> torch.device:
return self.current_device
def get_model_object(self, name: str) -> torch.nn.Module:
from . import utils
return utils.get_attr(self.model, name)
@property
def model_options(self) -> dict:
if not hasattr(self, "_model_options"):
setattr(self, "_model_options", {"transformer_options": {}})
return getattr(self, "_model_options")
@model_options.setter
def model_options(self, value):
setattr(self, "_model_options", value)
@dataclasses.dataclass
class MemoryMeasurements:
@@ -27,10 +27,11 @@ import torch.nn
from . import model_management, lora
from . import utils
from .comfy_types import UnetWrapperFunction
from .float import stochastic_rounding
from .model_base import BaseModel
from .model_management_types import ModelManageable, MemoryMeasurements
from .comfy_types import UnetWrapperFunction
def string_to_seed(data):
crc = 0xFFFFFFFF

@@ -45,6 +46,7 @@ def string_to_seed(data):
crc >>= 1
return crc ^ 0xFFFFFFFF
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()

@@ -106,7 +108,7 @@ class ModelPatcher(ModelManageable):
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self.model_options = {"transformer_options": {}}
self._model_options = {"transformer_options": {}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device

@@ -115,6 +117,14 @@ class ModelPatcher(ModelManageable):
self.ckpt_name = ckpt_name
self._memory_measurements = MemoryMeasurements(self.model)
@property
def model_options(self) -> dict:
return self._model_options
@model_options.setter
def model_options(self, value):
self._model_options = value
@property
def model_device(self) -> torch.device:
return self._memory_measurements.device
@@ -145,7 +155,7 @@ class ModelPatcher(ModelManageable):
n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n._model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
return n

@@ -260,6 +270,11 @@ class ModelPatcher(ModelManageable):
self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_dtype(self):
# this pokes into the internals of diffusion model a little bit
# todo: the base model isn't going to be aware that its diffusion model is patched this way
if isinstance(self.model, BaseModel):
diffusion_model = self.get_model_object("diffusion_model")
return diffusion_model.dtype
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()

@@ -293,7 +308,7 @@ class ModelPatcher(ModelManageable):
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
bk = self.backup.get(k, None)
bk: torch.nn.Module | None = self.backup.get(k, None)
if bk is not None:
weight = bk.weight
else:

@@ -494,7 +509,7 @@ class ModelPatcher(ModelManageable):
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
bk: torch.nn.Module | None = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
utils.copy_to_param(self.model, key, bk.weight)
@@ -538,14 +538,16 @@ class DiffusersLoader:
paths += get_huggingface_repo_list()
paths = list(frozenset(paths))
return {"required": {"model_path": (paths,), }}
return {"required": {"model_path": (paths,),
"weight_dtype": (FLUX_WEIGHT_DTYPES,)
}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "advanced/loaders"
def load_checkpoint(self, model_path, output_vae=True, output_clip=True):
def load_checkpoint(self, model_path, output_vae=True, output_clip=True,weight_dtype:str="default"):
for search_path in folder_paths.get_folder_paths("diffusers"):
if os.path.exists(search_path):
path = os.path.join(search_path, model_path)

@@ -556,7 +558,8 @@ class DiffusersLoader:
with comfy_tqdm():
model_path = snapshot_download(model_path)
return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"))
model_options = get_model_options_for_dtype(weight_dtype)
return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options)
class unCLIPCheckpointLoader:

@@ -875,6 +878,14 @@ class ControlNetApplyAdvanced:
out.append(c)
return (out[0], out[1])
def get_model_options_for_dtype(weight_dtype):
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
return model_options
class UNETLoader:
@classmethod

@@ -888,16 +899,14 @@ class UNETLoader:
CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype):
model_options = {}
if weight_dtype == "fp8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
model_options = get_model_options_for_dtype(weight_dtype)
unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
model = sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
@@ -19,15 +19,19 @@ from __future__ import annotations
import contextlib
import itertools
import json
import logging
import math
import os
import random
import struct
import sys
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Any
import accelerate
import numpy as np
import safetensors.torch
import torch

@@ -55,13 +59,27 @@ def _get_progress_bar_enabled():
setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled))
def load_torch_file(ckpt, safe_load=False, device=None):
def load_torch_file(ckpt: str, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt is None:
raise FileNotFoundError("the checkpoint was not found")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
elif ckpt.lower().endswith("index.json"):
# from accelerate
index_filename = ckpt
checkpoint_folder = os.path.split(index_filename)[0]
with open(index_filename) as f:
index = json.loads(f.read())
if "weight_map" in index:
index = index["weight_map"]
checkpoint_files = sorted(list(set(index.values())))
checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
sd: dict[str, torch.Tensor] = {}
for checkpoint_file in checkpoint_files:
sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type))
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
@@ -1,8 +1,11 @@
import logging
import torch
from torch.nn import LayerNorm
from comfy import model_management
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes
DIFFUSION_MODEL = "diffusion_model"

@@ -47,6 +50,65 @@ class TorchCompileModel:
return model,
class QuantizeModel(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"model": ("MODEL", {}),
"strategy": (["torchao", "quanto"], {"default": "torchao"})
}
}
FUNCTION = "execute"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
RETURN_TYPES = ("MODEL",)
def execute(self, model: ModelPatcher, strategy: str = "torchao"):
logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. All uses of this model will be quantized.")
logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations")
model = model.clone()
unet = model.get_model_object("diffusion_model")
# todo: quantize quantizes in place, which is not desired
# default exclusions
_unused_exclusions = {
"time_embedding.",
"add_embedding.",
"time_in.",
"txt_in.",
"vector_in.",
"img_in.",
"guidance_in.",
"final_layer.",
}
if strategy == "quanto":
from optimum.quanto import quantize, qint8
exclusion_list = [
name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None
]
quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
_in_place_fixme = unet
elif strategy == "torchao":
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight
model = model.clone()
unet = model.get_model_object("diffusion_model")
# todo: quantize quantizes in place, which is not desired
# def filter_fn(module: torch.nn.Module, name: str):
# return any("weight" in name for name, _ in (module.named_parameters())) and all(exclusion not in name for exclusion in exclusions)
quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device())
_in_place_fixme = unet
else:
raise ValueError(f"unknown strategy {strategy}")
model.add_object_patch("diffusion_model", _in_place_fixme)
return model,
NODE_CLASS_MAPPINGS = {
"TorchCompileModel": TorchCompileModel,
"QuantizeModel": QuantizeModel,
}
@@ -88,18 +88,14 @@ class UpscaleModelManageable(ModelManageable):
def model_dtype(self) -> torch.dtype:
return next(self.model.parameters()).dtype
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module:
self.model.to(device=device_to)
return self.model
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
self.model.to(device=device_to)
return self.model
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
self.model.to(device=offload_device)
return self.model
def __str__(self):
if self.ckpt_name is not None:
return f"<UpscaleModelManageable for {self.ckpt_name} ({self.model.__class__.__name__})>"