diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 13931d53b..c9a0a2183 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,8 +52,6 @@ jobs: runner: - labels: [self-hosted, Linux, X64, cpu] container: "ubuntu:22.04" - - labels: [self-hosted, Linux, X64, rocm-7600-8gb] - container: "rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0" - labels: [self-hosted, Linux, X64, cuda-3060-12gb] container: "nvcr.io/nvidia/pytorch:24.03-py3" steps: diff --git a/comfy/cli_args.py b/comfy/cli_args.py index ef54777f4..d436119a7 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -99,6 +99,8 @@ def _create_parser() -> EnhancedConfigArgParser: help="Use the new pytorch 2.0 cross attention function.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") + parser.add_argument("--disable-flash-attn", action="store_true", help="Disable Flash Attention") + parser.add_argument("--disable-sage-attention", action="store_true", help="Disable Sage Attention") upcast = parser.add_mutually_exclusive_group() upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.") diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 6cb9848ad..269d4bedc 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -75,6 +75,8 @@ class Configuration(dict): use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization. use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function. disable_xformers (bool): Disable xformers. + disable_flash_attn (bool): Disable flash_attn package attention. + disable_sage_attention (bool): Disable sage attention package attention. gpu_only (bool): Run everything on the GPU. highvram (bool): Keep models in GPU memory. normalvram (bool): Default VRAM usage setting. 
@@ -157,6 +159,8 @@ class Configuration(dict): self.use_quad_cross_attention: bool = False self.use_pytorch_cross_attention: bool = False self.disable_xformers: bool = False + self.disable_flash_attn: bool = False + self.disable_sage_attention: bool = False self.gpu_only: bool = False self.highvram: bool = False self.normalvram: bool = False diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index d6a09d8d8..3cb63a21a 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -13,9 +13,17 @@ def first_file(path, filenames) -> str | None: return None -def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None): - diffusion_model_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"] - unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names) +def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None, model_options=None): + if model_options is None: + model_options = {} + diffusion_model_names = [ + "diffusion_pytorch_model.fp16.safetensors", + "diffusion_pytorch_model.safetensors", + "diffusion_pytorch_model.fp16.bin", + "diffusion_pytorch_model.bin", + "diffusion_pytorch_model.safetensors.index.json" + ] + unet_path = first_file(os.path.join(model_path, "unet"), diffusion_model_names) or first_file(os.path.join(model_path, "transformer"), diffusion_model_names) vae_path = first_file(os.path.join(model_path, "vae"), diffusion_model_names) text_encoder_model_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"] @@ -28,7 +36,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire unet = None if unet_path is not None: - unet = sd.load_diffusion_model(unet_path) + unet = sd.load_diffusion_model(unet_path, model_options=model_options) clip = None textmodel_json_config1 = first_file(os.path.join(model_path, "text_encoder"), ["config.json"]) diff --git a/comfy/language/transformers_model_management.py b/comfy/language/transformers_model_management.py index f4db3800d..688cb6f81 100644 --- a/comfy/language/transformers_model_management.py +++ b/comfy/language/transformers_model_management.py @@ -79,13 +79,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel): # if we have flash-attn installed, try to use it try: - import flash_attn - attn_override_kwargs = { - "attn_implementation": "flash_attention_2", - **kwargs_to_try[0] - } - kwargs_to_try = (attn_override_kwargs, *kwargs_to_try) - logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried") + if model_management.flash_attn_enabled(): + attn_override_kwargs = { + "attn_implementation": "flash_attention_2", + **kwargs_to_try[0] + } + kwargs_to_try = (attn_override_kwargs, *kwargs_to_try) + logging.debug(f"while loading model {ckpt_name}, flash_attn was installed, so the flash_attention_2 implementation will be tried") except ImportError: pass for i, props in enumerate(kwargs_to_try): @@ -303,16 +303,13 @@ class TransformersManagedModel(ModelManageable, LanguageModel): def model_dtype(self) -> torch.dtype: return self.model.dtype - def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights=False) -> torch.nn.Module: - warnings.warn("Transformers models do not currently support adapters like LoRAs") + + def patch_model(self, device_to: 
torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module: return self.model.to(device=device_to) - def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module: + def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: return self.model.to(device=device_to) - def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: - return self.model.to(device=offload_device) - def patch_processor(self, processor: Any, overwrite_tokenizer: bool = False) -> TransformersManagedModel: model = copy.copy(self) model._processor = processor diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 0964086b5..11b241e2d 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -149,14 +149,16 @@ class DoubleStreamBlock(nn.Module): img_modulated = self.img_norm1(img) img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_qkv = self.img_attn.qkv(img_modulated) - img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + img_qkv = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + img_q, img_k, img_v = torch.unbind(img_qkv, dim=0) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention txt_modulated = self.txt_norm1(txt) txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_qkv = self.txt_attn.qkv(txt_modulated) - txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + txt_qkv = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + txt_q, txt_k, txt_v = torch.unbind(txt_qkv, dim=0) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) # run actual attention @@ -221,7 +223,8 @@ class SingleStreamBlock(nn.Module): x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + qkv = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = torch.unbind(qkv, dim=0) q, k = self.norm(q, k, v) # compute attention diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a0cf51cf6..c704ed12d 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -1,10 +1,12 @@ +import logging import math +from functools import wraps +from typing import Optional + import torch import torch.nn.functional as F -from torch import nn, einsum from einops import rearrange, repeat -from typing import Optional -import logging +from torch import nn, einsum from .diffusionmodules.util import AlphaBlender, timestep_embedding from .sub_quadratic_attention import efficient_dot_product_attention @@ -12,14 +14,22 @@ from ... 
import model_management if model_management.xformers_enabled(): import xformers # pylint: disable=import-error - import xformers.ops # pylint: disable=import-error + import xformers.ops # pylint: disable=import-error + +if model_management.sage_attention_enabled(): + from sageattention import sageattn + +if model_management.flash_attn_enabled(): + from flash_attn import flash_attn_func from ...cli_args import args from ... import ops + ops = ops.disable_weight_init FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype() + def get_attn_precision(attn_precision): if args.dont_upcast_attention: return None @@ -27,12 +37,13 @@ def get_attn_precision(attn_precision): return FORCE_UPCAST_ATTENTION_DTYPE return attn_precision + def exists(val): return val is not None def uniq(arr): - return{el: True for el in arr}.keys() + return {el: True for el in arr}.keys() def default(val, d): @@ -82,9 +93,11 @@ class FeedForward(nn.Module): def forward(self, x): return self.net(x) + def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) + def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): attn_precision = get_attn_precision(attn_precision) @@ -98,7 +111,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape h = heads if skip_reshape: - q, k, v = map( + q, k, v = map( lambda t: t.reshape(b * heads, -1, dim_head), (q, k, v), ) @@ -122,7 +135,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape if exists(mask): if mask.dtype == torch.bool: - mask = rearrange(mask, 'b ... -> b (...)') #TODO: check if this bool part matches pytorch attention + mask = rearrange(mask, 'b ... 
-> b (...)') # TODO: check if this bool part matches pytorch attention max_neg_value = -torch.finfo(sim.dtype).max mask = repeat(mask, 'b j -> (b h) () j', h=h) sim.masked_fill_(~mask, max_neg_value) @@ -167,13 +180,12 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head) key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) - dtype = query.dtype upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32 if upcast_attention: - bytes_per_token = torch.finfo(torch.float32).bits//8 + bytes_per_token = torch.finfo(torch.float32).bits // 8 else: - bytes_per_token = torch.finfo(query.dtype).bits//8 + bytes_per_token = torch.finfo(query.dtype).bits // 8 batch_x_heads, q_tokens, _ = query.shape _, _, k_tokens = key.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens @@ -215,9 +227,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, hidden_states = hidden_states.to(dtype) - hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) + hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1, 2).flatten(start_dim=2) return hidden_states + def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): attn_precision = get_attn_precision(attn_precision) @@ -231,7 +244,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape h = heads if skip_reshape: - q, k, v = map( + q, k, v = map( lambda t: t.reshape(b * heads, -1, dim_head), (q, k, v), ) @@ -262,16 +275,15 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape mem_required = tensor_size * modifier steps = 1 - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") if steps > 64: max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' - f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') if mask is not None: if len(mask.shape) == 2: @@ -289,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape for i in range(0, q.shape[1], slice_size): end = i + slice_size if upcast: - with torch.autocast(enabled=False, device_type = 'cuda'): + with torch.autocast(enabled=False, device_type='cuda'): s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale else: s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale @@ -331,11 +343,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return r1 -BROKEN_XFORMERS = False -if model_management.xformers_enabled(): - x_vers = xformers.__version__ - # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error) - BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20") def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): if skip_reshape: @@ -346,10 +353,6 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh disabled_xformers = False - if BROKEN_XFORMERS: - if b * heads > 65535: - disabled_xformers = True - if not disabled_xformers: if torch.jit.is_tracing() or torch.jit.is_scripting(): disabled_xformers = True @@ -358,7 +361,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) if skip_reshape: - q, k, v = map( + q, k, v = map( lambda t: t.reshape(b * heads, -1, dim_head), (q, k, v), ) @@ -390,22 +393,36 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh return out -def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): - if skip_reshape: - b, _, _, dim_head = q.shape - else: - b, _, dim_head = q.shape - dim_head //= heads - q, k, v = map( - lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), - (q, k, v), - ) +def pytorch_style_decl(func): + @wraps(func) + def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): + if skip_reshape: + b, _, _, dim_head = q.shape + else: + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map( + lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), + (q, k, v), + ) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) - out = ( - out.transpose(1, 2).reshape(b, -1, heads * dim_head) - ) - return out + out = func(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape) + out = out.transpose(1, 2).reshape(b, -1, heads * dim_head) + return out + + return wrapper + +@pytorch_style_decl +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + +@pytorch_style_decl +def attention_sagemaker(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): + return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + +@pytorch_style_decl +def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False): + return flash_attn_func(q, k, v) optimized_attention = attention_basic @@ -426,10 +443,11 @@ else: optimized_attention_masked 
= optimized_attention + def optimized_attention_for_device(device, mask=False, small_input=False): if small_input: if model_management.pytorch_attention_enabled(): - return attention_pytorch #TODO: need to confirm but this is probably slightly faster for small inputs in all cases + return attention_pytorch # TODO: need to confirm but this is probably slightly faster for small inputs in all cases else: return attention_basic @@ -493,7 +511,7 @@ class BasicTransformerBlock(nn.Module): self.disable_self_attn = disable_self_attn self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout, - context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn + context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations) if disable_temporal_crossattention: @@ -507,7 +525,7 @@ class BasicTransformerBlock(nn.Module): context_dim_attn2 = context_dim self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2, - heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none + heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device) @@ -641,6 +659,7 @@ class SpatialTransformer(nn.Module): Finally, reshape to image NEW: use_linear for more efficiency instead of the 1x1 convs """ + def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None, disable_self_attn=False, use_linear=False, @@ -653,23 +672,23 @@ class SpatialTransformer(nn.Module): self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) if not use_linear: self.proj_in = operations.Conv2d(in_channels, - inner_dim, - kernel_size=1, - stride=1, - padding=0, dtype=dtype, device=device) + inner_dim, + kernel_size=1, + stride=1, + padding=0, dtype=dtype, device=device) else: self.proj_in = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.transformer_blocks = nn.ModuleList( [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations) - for d in range(depth)] + for d in range(depth)] ) if not use_linear: - self.proj_out = operations.Conv2d(inner_dim,in_channels, - kernel_size=1, - stride=1, - padding=0, dtype=dtype, device=device) + self.proj_out = operations.Conv2d(inner_dim, in_channels, + kernel_size=1, + stride=1, + padding=0, dtype=dtype, device=device) else: self.proj_out = operations.Linear(in_channels, inner_dim, dtype=dtype, device=device) self.use_linear = use_linear @@ -699,27 +718,27 @@ class SpatialTransformer(nn.Module): class SpatialVideoTransformer(SpatialTransformer): def __init__( - self, - in_channels, - n_heads, - d_head, - 
depth=1, - dropout=0.0, - use_linear=False, - context_dim=None, - use_spatial_context=False, - timesteps=None, - merge_strategy: str = "fixed", - merge_factor: float = 0.5, - time_context_dim=None, - ff_in=False, - checkpoint=False, - time_depth=1, - disable_self_attn=False, - disable_temporal_crossattention=False, - max_time_embed_period: int = 10000, - attn_precision=None, - dtype=None, device=None, operations=ops + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + use_linear=False, + context_dim=None, + use_spatial_context=False, + timesteps=None, + merge_strategy: str = "fixed", + merge_factor: float = 0.5, + time_context_dim=None, + ff_in=False, + checkpoint=False, + time_depth=1, + disable_self_attn=False, + disable_temporal_crossattention=False, + max_time_embed_period: int = 10000, + attn_precision=None, + dtype=None, device=None, operations=ops ): super().__init__( in_channels, @@ -785,13 +804,13 @@ class SpatialVideoTransformer(SpatialTransformer): ) def forward( - self, - x: torch.Tensor, - context: Optional[torch.Tensor] = None, - time_context: Optional[torch.Tensor] = None, - timesteps: Optional[int] = None, - image_only_indicator: Optional[torch.Tensor] = None, - transformer_options={} + self, + x: torch.Tensor, + context: Optional[torch.Tensor] = None, + time_context: Optional[torch.Tensor] = None, + timesteps: Optional[int] = None, + image_only_indicator: Optional[torch.Tensor] = None, + transformer_options={} ) -> torch.Tensor: _, _, h, w = x.shape x_in = x @@ -801,7 +820,7 @@ class SpatialVideoTransformer(SpatialTransformer): if self.use_spatial_context: assert ( - context.ndim == 3 + context.ndim == 3 ), f"n dims of spatial context should be 3 but are {context.ndim}" if time_context is None: @@ -830,7 +849,7 @@ class SpatialVideoTransformer(SpatialTransformer): emb = emb[:, None, :] for it_, (block, mix_block) in enumerate( - zip(self.transformer_blocks, self.time_stack) + zip(self.transformer_blocks, self.time_stack) ): transformer_options["block_index"] = it_ x = block( @@ -844,7 +863,7 @@ class SpatialVideoTransformer(SpatialTransformer): B, S, C = x_mix.shape x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps) - x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options + x_mix = mix_block(x_mix, context=time_context) # TODO: transformer_options x_mix = rearrange( x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps ) @@ -858,5 +877,3 @@ class SpatialVideoTransformer(SpatialTransformer): x = self.proj_out(x) out = x + x_in return out - - diff --git a/comfy/model_base.py b/comfy/model_base.py index c718cbbd5..3e167eebd 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -108,7 +108,6 @@ class BaseModel(torch.nn.Module): operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) if model_management.force_channels_last(): - # todo: ??? 
self.diffusion_model.to(memory_format=torch.channels_last) logging.debug("using channels last mode for diffusion model") logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype)) diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 5fae0855b..b566b1cf2 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -535,10 +535,6 @@ def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]: for user_dir in Path(local_dir_root).iterdir(): for model_dir in user_dir.iterdir(): - try: - _hf_fs.resolve_path(str(user_dir / model_dir)) - except Exception as exc_info: - logging.debug(f"HuggingFaceFS did not think this was a valid repo: {user_dir.name}/{model_dir.name} with error {exc_info}", exc_info) existing_local_dir_repos.add(f"{user_dir.name}/{model_dir.name}") known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS) diff --git a/comfy/model_management.py b/comfy/model_management.py index 79e11f1e2..0f43a9b80 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -23,7 +23,7 @@ import sys import warnings from enum import Enum from threading import RLock -from typing import Literal, List, Sequence +from typing import Literal, List, Sequence, Final import psutil import torch @@ -128,6 +128,9 @@ def get_torch_device(): return torch.device("xpu", torch.xpu.current_device()) else: try: + # https://github.com/sayakpaul/diffusers-torchao/blob/bade7a6abb1cab9ef44782e6bcfab76d0237ae1f/inference/benchmark_image.py#L3 + # This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer. + torch.set_float32_matmul_precision("high") return torch.device(torch.cuda.current_device()) except: warnings.warn("torch.cuda.current_device() did not return a device, returning a CPU torch device") @@ -319,7 +322,7 @@ try: except: logging.warning("Could not pick default device.") -current_loaded_models: List["LoadedModel"] = [] +current_loaded_models: Final[List["LoadedModel"]] = [] def module_size(module): @@ -974,6 +977,22 @@ def cast_to_device(tensor, device, dtype, copy=False): else: return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking) +FLASH_ATTENTION_ENABLED = False +if not args.disable_flash_attn: + try: + import flash_attn + FLASH_ATTENTION_ENABLED = True + except ImportError: + pass + +SAGE_ATTENTION_ENABLED = False +if not args.disable_sage_attention: + try: + import sageattention + SAGE_ATTENTION_ENABLED = True + except ImportError: + pass + def xformers_enabled(): global directml_device @@ -986,6 +1005,30 @@ def xformers_enabled(): return False return XFORMERS_IS_AVAILABLE +def flash_attn_enabled(): + global directml_device + global cpu_state + if cpu_state != CPUState.GPU: + return False + if is_intel_xpu(): + return False + if directml_device: + return False + return FLASH_ATTENTION_ENABLED + +def sage_attention_enabled(): + global directml_device + global cpu_state + if cpu_state != CPUState.GPU: + return False + if is_intel_xpu(): + return False + if directml_device: + return False + if xformers_enabled(): + return False + return SAGE_ATTENTION_ENABLED + def xformers_enabled_vae(): enabled = xformers_enabled() diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py index 44e709e61..538939dcf 100644 --- a/comfy/model_management_types.py +++ b/comfy/model_management_types.py @@ -55,17 +55,13 @@ class ModelManageable(Protocol): def model_dtype(self) -> torch.dtype: return next(self.model.parameters()).dtype - def 
patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module: - self.patch_model(device_to=device_to, patch_weights=False) - return self.model - - def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module: + def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module: ... - def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: + def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: """ Unloads the model by moving it to the offload device - :param offload_device: + :param device_to: :param unpatch_weights: :return: """ @@ -99,6 +95,20 @@ class ModelManageable(Protocol): def current_loaded_device(self) -> torch.device: return self.current_device + def get_model_object(self, name: str) -> torch.nn.Module: + from . import utils + return utils.get_attr(self.model, name) + + @property + def model_options(self) -> dict: + if not hasattr(self, "_model_options"): + setattr(self, "_model_options", {"transformer_options": {}}) + return getattr(self, "_model_options") + + @model_options.setter + def model_options(self, value): + setattr(self, "_model_options", value) + @dataclasses.dataclass class MemoryMeasurements: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c2102e239..213dd3eb4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -27,10 +27,11 @@ import torch.nn from . import model_management, lora from . import utils +from .comfy_types import UnetWrapperFunction from .float import stochastic_rounding from .model_base import BaseModel from .model_management_types import ModelManageable, MemoryMeasurements -from .comfy_types import UnetWrapperFunction + def string_to_seed(data): crc = 0xFFFFFFFF @@ -45,6 +46,7 @@ def string_to_seed(data): crc >>= 1 return crc ^ 0xFFFFFFFF + def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -106,7 +108,7 @@ class ModelPatcher(ModelManageable): self.backup = {} self.object_patches = {} self.object_patches_backup = {} - self.model_options = {"transformer_options": {}} + self._model_options = {"transformer_options": {}} self.model_size() self.load_device = load_device self.offload_device = offload_device @@ -115,6 +117,14 @@ class ModelPatcher(ModelManageable): self.ckpt_name = ckpt_name self._memory_measurements = MemoryMeasurements(self.model) + @property + def model_options(self) -> dict: + return self._model_options + + @model_options.setter + def model_options(self, value): + self._model_options = value + @property def model_device(self) -> torch.device: return self._memory_measurements.device @@ -145,7 +155,7 @@ class ModelPatcher(ModelManageable): n.patches_uuid = self.patches_uuid n.object_patches = self.object_patches.copy() - n.model_options = copy.deepcopy(self.model_options) + n._model_options = copy.deepcopy(self.model_options) n.backup = self.backup n.object_patches_backup = self.object_patches_backup return n @@ -260,6 +270,11 @@ class ModelPatcher(ModelManageable): self.model_options["model_function_wrapper"] = wrap_func.to(device) def model_dtype(self): + # this pokes into the internals of diffusion model a little bit + # 
todo: the base model isn't going to be aware that its diffusion model is patched this way + if isinstance(self.model, BaseModel): + diffusion_model = self.get_model_object("diffusion_model") + return diffusion_model.dtype if hasattr(self.model, "get_dtype"): return self.model.get_dtype() @@ -293,7 +308,7 @@ class ModelPatcher(ModelManageable): if filter_prefix is not None: if not k.startswith(filter_prefix): continue - bk = self.backup.get(k, None) + bk: torch.nn.Module | None = self.backup.get(k, None) if bk is not None: weight = bk.weight else: @@ -494,7 +509,7 @@ class ModelPatcher(ModelManageable): if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: for key in [weight_key, bias_key]: - bk = self.backup.get(key, None) + bk: torch.nn.Module | None = self.backup.get(key, None) if bk is not None: if bk.inplace_update: utils.copy_to_param(self.model, key, bk.weight) diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 7018b2170..bebf1e3a1 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -538,14 +538,16 @@ class DiffusersLoader: paths += get_huggingface_repo_list() paths = list(frozenset(paths)) - return {"required": {"model_path": (paths,), }} + return {"required": {"model_path": (paths,), + "weight_dtype": (FLUX_WEIGHT_DTYPES,) + }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" CATEGORY = "advanced/loaders" - def load_checkpoint(self, model_path, output_vae=True, output_clip=True): + def load_checkpoint(self, model_path, output_vae=True, output_clip=True,weight_dtype:str="default"): for search_path in folder_paths.get_folder_paths("diffusers"): if os.path.exists(search_path): path = os.path.join(search_path, model_path) @@ -556,7 +558,8 @@ class DiffusersLoader: with comfy_tqdm(): model_path = snapshot_download(model_path) - return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + model_options = get_model_options_for_dtype(weight_dtype) + return diffusers_load.load_diffusers(model_path, output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings"), model_options=model_options) class unCLIPCheckpointLoader: @@ -875,6 +878,14 @@ class ControlNetApplyAdvanced: out.append(c) return (out[0], out[1]) +def get_model_options_for_dtype(weight_dtype): + model_options = {} + if weight_dtype == "fp8_e4m3fn": + model_options["dtype"] = torch.float8_e4m3fn + elif weight_dtype == "fp8_e5m2": + model_options["dtype"] = torch.float8_e5m2 + return model_options + class UNETLoader: @classmethod @@ -888,16 +899,14 @@ class UNETLoader: CATEGORY = "advanced/loaders" def load_unet(self, unet_name, weight_dtype): - model_options = {} - if weight_dtype == "fp8_e4m3fn": - model_options["dtype"] = torch.float8_e4m3fn - elif weight_dtype == "fp8_e5m2": - model_options["dtype"] = torch.float8_e5m2 + model_options = get_model_options_for_dtype(weight_dtype) unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS) model = sd.load_diffusion_model(unet_path, model_options=model_options) return (model,) + + class CLIPLoader: @classmethod def INPUT_TYPES(s): diff --git a/comfy/utils.py b/comfy/utils.py index 10cc7dddf..c39343d8b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -19,15 +19,19 @@ from __future__ import annotations import contextlib import itertools +import json import logging import math +import os import random import struct import sys 
import warnings from contextlib import contextmanager +from pathlib import Path from typing import Optional, Any +import accelerate import numpy as np import safetensors.torch import torch @@ -55,13 +59,27 @@ def _get_progress_bar_enabled(): setattr(sys.modules[__name__], 'PROGRESS_BAR_ENABLED', property(_get_progress_bar_enabled)) -def load_torch_file(ckpt, safe_load=False, device=None): +def load_torch_file(ckpt: str, safe_load=False, device=None): if device is None: device = torch.device("cpu") if ckpt is None: raise FileNotFoundError("the checkpoint was not found") if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): sd = safetensors.torch.load_file(ckpt, device=device.type) + elif ckpt.lower().endswith("index.json"): + # from accelerate + index_filename = ckpt + checkpoint_folder = os.path.split(index_filename)[0] + with open(index_filename) as f: + index = json.loads(f.read()) + + if "weight_map" in index: + index = index["weight_map"] + checkpoint_files = sorted(list(set(index.values()))) + checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files] + sd: dict[str, torch.Tensor] = {} + for checkpoint_file in checkpoint_files: + sd.update(safetensors.torch.load_file(str(checkpoint_file), device=device.type)) else: if safe_load: if not 'weights_only' in torch.load.__code__.co_varnames: diff --git a/comfy_extras/nodes/nodes_torch_compile.py b/comfy_extras/nodes/nodes_torch_compile.py index 802856829..23ffb8474 100644 --- a/comfy_extras/nodes/nodes_torch_compile.py +++ b/comfy_extras/nodes/nodes_torch_compile.py @@ -1,8 +1,11 @@ import logging import torch +from torch.nn import LayerNorm +from comfy import model_management from comfy.model_patcher import ModelPatcher +from comfy.nodes.package_typing import CustomNode, InputTypes DIFFUSION_MODEL = "diffusion_model" @@ -47,6 +50,65 @@ class TorchCompileModel: return model, +class QuantizeModel(CustomNode): + @classmethod + def INPUT_TYPES(cls) -> InputTypes: + return { + "required": { + "model": ("MODEL", {}), + "strategy": (["torchao", "quanto"], {"default": "torchao"}) + } + } + + FUNCTION = "execute" + CATEGORY = "_for_testing" + EXPERIMENTAL = True + + RETURN_TYPES = ("MODEL",) + + def execute(self, model: ModelPatcher, strategy: str = "torchao"): + logging.warning(f"Quantizing {model} this way quantizes it in place, making it insuitable for cloning. 
All uses of this model will be quantized.") + logging.warning(f"Quantizing {model} will produce poor results due to Optimum's limitations") + model = model.clone() + unet = model.get_model_object("diffusion_model") + # todo: quantize quantizes in place, which is not desired + + # default exclusions + _unused_exclusions = { + "time_embedding.", + "add_embedding.", + "time_in.", + "txt_in.", + "vector_in.", + "img_in.", + "guidance_in.", + "final_layer.", + } + if strategy == "quanto": + from optimum.quanto import quantize, qint8 + exclusion_list = [ + name for name, module in unet.named_modules() if isinstance(module, LayerNorm) and module.weight is None + ] + quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list) + _in_place_fixme = unet + elif strategy == "torchao": + from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight + model = model.clone() + unet = model.get_model_object("diffusion_model") + # todo: quantize quantizes in place, which is not desired + + # def filter_fn(module: torch.nn.Module, name: str): + # return any("weight" in name for name, _ in (module.named_parameters())) and all(exclusion not in name for exclusion in exclusions) + quantize_(unet, int8_dynamic_activation_int8_weight(), device=model_management.get_torch_device()) + _in_place_fixme = unet + else: + raise ValueError(f"unknown strategy {strategy}") + + model.add_object_patch("diffusion_model", _in_place_fixme) + return model, + + NODE_CLASS_MAPPINGS = { "TorchCompileModel": TorchCompileModel, + "QuantizeModel": QuantizeModel, } diff --git a/comfy_extras/nodes/nodes_upscale_model.py b/comfy_extras/nodes/nodes_upscale_model.py index 3d157a2b0..1a84b358f 100644 --- a/comfy_extras/nodes/nodes_upscale_model.py +++ b/comfy_extras/nodes/nodes_upscale_model.py @@ -88,18 +88,14 @@ class UpscaleModelManageable(ModelManageable): def model_dtype(self) -> torch.dtype: return next(self.model.parameters()).dtype - def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module: + def patch_model(self, device_to: torch.device | None = None, lowvram_model_memory: int = 0, load_weights: bool = True, force_patch_weights: bool = False) -> torch.nn.Module: self.model.to(device=device_to) return self.model - def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module: + def unpatch_model(self, device_to: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: self.model.to(device=device_to) return self.model - def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module: - self.model.to(device=offload_device) - return self.model - def __str__(self): if self.ckpt_name is not None: return f""
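Notes on the attention changes above: comfy/model_management.py now exposes flash_attn_enabled() and sage_attention_enabled(), gated by the new --disable-flash-attn and --disable-sage-attention flags and by whether the flash_attn / sageattention packages import, and comfy/ldm/modules/attention.py wraps both backends with the pytorch_style_decl reshape decorator. The dispatch that chooses between these wrappers sits outside the hunks shown here, so the selection order below is only an illustrative sketch assembled from the helpers this patch does add (function names follow the patch).

# Illustrative sketch, not the repository's actual dispatch: shows how the new
# helpers and wrappers from this patch could be combined to pick a backend.
from comfy import model_management
from comfy.ldm.modules import attention

def pick_optimized_attention():
    # xformers keeps its existing priority; sage_attention_enabled() already
    # returns False when xformers is active, per the model_management hunk.
    if model_management.xformers_enabled():
        return attention.attention_xformers
    if model_management.sage_attention_enabled():
        return attention.attention_sagemaker      # SageAttention wrapper added by the patch
    if model_management.flash_attn_enabled():
        return attention.attention_flash_attn     # flash-attn wrapper added by the patch
    if model_management.pytorch_attention_enabled():
        return attention.attention_pytorch
    return attention.attention_basic

The comfy/utils.py hunk also teaches load_torch_file to follow an accelerate-style *.index.json weight map and merge the listed shards, and load_diffusers now accepts diffusion_pytorch_model.safetensors.index.json in either the unet or transformer folder. A usage sketch with a hypothetical sharded layout:

# Hypothetical folder layout (paths are examples, not taken from the patch):
#   unet/diffusion_pytorch_model.safetensors.index.json   maps tensor names to shard files
#   unet/diffusion_pytorch_model-00001-of-00002.safetensors
#   unet/diffusion_pytorch_model-00002-of-00002.safetensors
from comfy.utils import load_torch_file

state_dict = load_torch_file("unet/diffusion_pytorch_model.safetensors.index.json")
# Every shard listed under the index's "weight_map" is loaded with safetensors
# and merged into a single state dict, on the CPU by default.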