Merge branch 'master' of github.com:comfyanonymous/ComfyUI

doctorpangloss 2024-11-04 10:17:26 -08:00
commit 772e768fe8
13 changed files with 382 additions and 97 deletions

View File

@@ -22,6 +22,7 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyui)
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
+- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.

View File

@@ -82,7 +82,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio
ModelPaths(["checkpoints"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["configs"], additional_absolute_directory_paths={get_package_as_path("comfy.configs")}, supported_extensions={".yaml"}),
ModelPaths(["vae"], supported_extensions=set(supported_pt_extensions)),
-ModelPaths(["clip"], supported_extensions=set(supported_pt_extensions)),
+ModelPaths(folder_names=["clip", "text_encoders"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(["loras"], supported_extensions=set(supported_pt_extensions)),
ModelPaths(folder_names=["diffusion_models", "unet"], supported_extensions=set(supported_pt_extensions), folder_names_are_relative_directory_paths_too=True),
ModelPaths(["clip_vision"], supported_extensions=set(supported_pt_extensions)),

View File

@@ -16,7 +16,7 @@ class NoiseScheduleVP:
continuous_beta_0=0.1,
continuous_beta_1=20.,
):
-"""Create a wrapper class for the forward SDE (VP type).
+r"""Create a wrapper class for the forward SDE (VP type).
***
Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.

View File

@@ -2,12 +2,16 @@
#adapted to ComfyUI
from typing import Callable, List, Optional, Tuple, Union
+from functools import partial
+import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
+from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -158,8 +162,10 @@ class ResBlock(nn.Module):
*,
affine: bool = True,
attn_block: Optional[nn.Module] = None,
-padding_mode: str = "replicate",
causal: bool = True,
+prune_bottleneck: bool = False,
+padding_mode: str,
+bias: bool = True,
):
super().__init__()
self.channels = channels
@@ -170,23 +176,23 @@ class ResBlock(nn.Module):
nn.SiLU(inplace=True),
PConv3d(
in_channels=channels,
-out_channels=channels,
+out_channels=channels // 2 if prune_bottleneck else channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
-bias=True,
-# causal=causal,
+bias=bias,
+causal=causal,
),
norm_fn(channels, affine=affine),
nn.SiLU(inplace=True),
PConv3d(
-in_channels=channels,
+in_channels=channels // 2 if prune_bottleneck else channels,
out_channels=channels,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
padding_mode=padding_mode,
-bias=True,
-# causal=causal,
+bias=bias,
+causal=causal,
),
)
@@ -206,6 +212,81 @@ class ResBlock(nn.Module):
return self.attn_block(x)
class Attention(nn.Module):
def __init__(
self,
dim: int,
head_dim: int = 32,
qkv_bias: bool = False,
out_bias: bool = True,
qk_norm: bool = True,
) -> None:
super().__init__()
self.head_dim = head_dim
self.num_heads = dim // head_dim
self.qk_norm = qk_norm
self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
self.out = nn.Linear(dim, dim, bias=out_bias)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""Compute temporal self-attention.
Args:
x: Input tensor. Shape: [B, C, T, H, W].
chunk_size: Chunk size for large tensors.
Returns:
x: Output tensor. Shape: [B, C, T, H, W].
"""
B, _, T, H, W = x.shape
if T == 1:
# No attention for single frame.
x = x.movedim(1, -1) # [B, C, T, H, W] -> [B, T, H, W, C]
qkv = self.qkv(x)
_, _, x = qkv.chunk(3, dim=-1) # Throw away queries and keys.
x = self.out(x)
return x.movedim(-1, 1) # [B, T, H, W, C] -> [B, C, T, H, W]
# 1D temporal attention.
x = rearrange(x, "B C t h w -> (B h w) t C")
qkv = self.qkv(x)
# Input: qkv with shape [B, t, 3 * num_heads * head_dim]
# Output: x with shape [B, num_heads, t, head_dim]
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)
if self.qk_norm:
q = F.normalize(q, p=2, dim=-1)
k = F.normalize(k, p=2, dim=-1)
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
assert x.size(0) == q.size(0)
x = self.out(x)
x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
return x
class AttentionBlock(nn.Module):
def __init__(
self,
dim: int,
**attn_kwargs,
) -> None:
super().__init__()
self.norm = norm_fn(dim)
self.attn = Attention(dim, **attn_kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.attn(self.norm(x))
class CausalUpsampleBlock(nn.Module):
def __init__(
self,
@@ -244,14 +325,9 @@ class CausalUpsampleBlock(nn.Module):
return x
-def block_fn(channels, *, has_attention: bool = False, **block_kwargs):
-assert has_attention is False #NOTE: if this is ever true add back the attention code.
-attn_block = None #AttentionBlock(channels) if has_attention else None
-return ResBlock(
-channels, affine=True, attn_block=attn_block, **block_kwargs
-)
+def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
+attn_block = AttentionBlock(channels) if has_attention else None
+return ResBlock(channels, affine=affine, attn_block=attn_block, **block_kwargs)
class DownsampleBlock(nn.Module):
@@ -288,8 +364,9 @@ class DownsampleBlock(nn.Module):
out_channels=out_channels,
kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
stride=(temporal_reduction, spatial_reduction, spatial_reduction),
+# First layer in each block always uses replicate padding
padding_mode="replicate",
-bias=True,
+bias=block_kwargs["bias"],
)
)
@@ -382,7 +459,7 @@ class Decoder(nn.Module):
blocks = []
first_block = [
-nn.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
+ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
] # Input layer.
# First set of blocks preserve channel count.
for _ in range(num_res_blocks[-1]):
@@ -452,11 +529,165 @@ class Decoder(nn.Module):
return self.output_proj(x).contiguous()
class LatentDistribution:
def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
"""Initialize latent distribution.
Args:
mean: Mean of the distribution. Shape: [B, C, T, H, W].
logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
"""
assert mean.shape == logvar.shape
self.mean = mean
self.logvar = logvar
def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
if temperature == 0.0:
return self.mean
if noise is None:
noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
else:
assert noise.device == self.mean.device
noise = noise.to(self.mean.dtype)
if temperature != 1.0:
raise NotImplementedError(f"Temperature {temperature} is not supported.")
# Just Gaussian sample with no scaling of variance.
return noise * torch.exp(self.logvar * 0.5) + self.mean
def mode(self):
return self.mean
class Encoder(nn.Module):
def __init__(
self,
*,
in_channels: int,
base_channels: int,
channel_multipliers: List[int],
num_res_blocks: List[int],
latent_dim: int,
temporal_reductions: List[int],
spatial_reductions: List[int],
prune_bottlenecks: List[bool],
has_attentions: List[bool],
affine: bool = True,
bias: bool = True,
input_is_conv_1x1: bool = False,
padding_mode: str,
):
super().__init__()
self.temporal_reductions = temporal_reductions
self.spatial_reductions = spatial_reductions
self.base_channels = base_channels
self.channel_multipliers = channel_multipliers
self.num_res_blocks = num_res_blocks
self.latent_dim = latent_dim
self.fourier_features = FourierFeatures()
ch = [mult * base_channels for mult in channel_multipliers]
num_down_blocks = len(ch) - 1
assert len(num_res_blocks) == num_down_blocks + 2
layers = (
[ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
if not input_is_conv_1x1
else [Conv1x1(in_channels, ch[0])]
)
assert len(prune_bottlenecks) == num_down_blocks + 2
assert len(has_attentions) == num_down_blocks + 2
block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)
for _ in range(num_res_blocks[0]):
layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
prune_bottlenecks = prune_bottlenecks[1:]
has_attentions = has_attentions[1:]
assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
for i in range(num_down_blocks):
layer = DownsampleBlock(
ch[i],
ch[i + 1],
num_res_blocks=num_res_blocks[i + 1],
temporal_reduction=temporal_reductions[i],
spatial_reduction=spatial_reductions[i],
prune_bottleneck=prune_bottlenecks[i],
has_attention=has_attentions[i],
affine=affine,
bias=bias,
padding_mode=padding_mode,
)
layers.append(layer)
# Additional blocks.
for _ in range(num_res_blocks[-1]):
layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))
self.layers = nn.Sequential(*layers)
# Output layers.
self.output_norm = norm_fn(ch[-1])
self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)
@property
def temporal_downsample(self):
return math.prod(self.temporal_reductions)
@property
def spatial_downsample(self):
return math.prod(self.spatial_reductions)
def forward(self, x) -> LatentDistribution:
"""Forward pass.
Args:
x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]
Returns:
means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
logvar: Shape: [B, latent_dim, t, h, w].
"""
assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
x = self.fourier_features(x)
x = self.layers(x)
x = self.output_norm(x)
x = F.silu(x, inplace=True)
x = self.output_proj(x)
means, logvar = torch.chunk(x, 2, dim=1)
assert means.ndim == 5
assert logvar.shape == means.shape
assert means.size(1) == self.latent_dim
return LatentDistribution(means, logvar)
class VideoVAE(nn.Module):
def __init__(self):
super().__init__()
-self.encoder = None #TODO once the model releases
+self.encoder = Encoder(
in_channels=15,
base_channels=64,
channel_multipliers=[1, 2, 4, 6],
num_res_blocks=[3, 3, 4, 6, 3],
latent_dim=12,
temporal_reductions=[1, 2, 3],
spatial_reductions=[2, 2, 2],
prune_bottlenecks=[False, False, False, False, False],
has_attentions=[False, True, True, True, True],
affine=True,
bias=True,
input_is_conv_1x1=True,
padding_mode="replicate"
)
)
self.decoder = Decoder(
out_channels=3,
base_channels=128,
@@ -474,7 +705,7 @@ class VideoVAE(nn.Module):
)
def encode(self, x):
-return self.encoder(x)
+return self.encoder(x).mode()
def decode(self, x):
return self.decoder(x)
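With this change `Encoder.forward` returns a `LatentDistribution`, and `VideoVAE.encode` takes its `.mode()` (the mean) rather than a random draw; `sample()` uses the usual reparameterization z = mean + exp(logvar / 2) * noise. A small sketch under those assumptions, using toy tensors shaped [B, C, T, H, W]:

```python
import torch

# Toy (mean, logvar) pair standing in for the encoder output; shapes follow [B, C, T, H, W].
mean = torch.zeros(1, 12, 3, 8, 8)
logvar = torch.full_like(mean, -2.0)

# Deterministic "mode": what VideoVAE.encode() now returns.
z_mode = mean

# Stochastic draw via reparameterization, matching LatentDistribution.sample(temperature=1.0).
noise = torch.randn_like(mean)
z_sample = mean + torch.exp(0.5 * logvar) * noise

print(z_mode.shape, z_sample.shape)
```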

View File

@@ -395,6 +395,13 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
return out
+if model_management.is_nvidia(): # pytorch 2.3 and up seem to have this issue.
+SDP_BATCH_LIMIT = 2 ** 15
+else:
+# TODO: other GPUs ?
+SDP_BATCH_LIMIT = 2 ** 31
def pytorch_style_decl(func):
@wraps(func)
def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
@@ -408,8 +415,15 @@ def pytorch_style_decl(func):
(q, k, v),
)
-out = func(q, k, v, heads, mask=mask, attn_precision=attn_precision, skip_reshape=skip_reshape)
-out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+if SDP_BATCH_LIMIT >= q.shape[0]:
+out = func(q, k, v, heads=heads, mask=mask, attn_precision=attn_precision)
+out = (
+out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+)
+else:
+out = torch.empty((q.shape[0], q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
+for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
+out[i: i + SDP_BATCH_LIMIT] = func(q[i: i + SDP_BATCH_LIMIT], k[i: i + SDP_BATCH_LIMIT], v[i: i + SDP_BATCH_LIMIT], heads=heads, mask=mask, attn_precision=attn_precision).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
return wrapper
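The wrapper now slices very large batches before calling the attention kernel, working around an apparent scaled-dot-product-attention issue on NVIDIA with PyTorch 2.3 and up. A standalone sketch of the same chunking pattern using `torch.nn.functional.scaled_dot_product_attention`; the limit value and tensor shapes here are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

SDP_BATCH_LIMIT = 2 ** 15  # assumed threshold, taken from the comment in this diff

def chunked_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Run scaled-dot-product attention in slices along the batch dimension when it is too large."""
    if q.shape[0] <= SDP_BATCH_LIMIT:
        return F.scaled_dot_product_attention(q, k, v)
    out = torch.empty_like(q)
    for i in range(0, q.shape[0], SDP_BATCH_LIMIT):
        out[i:i + SDP_BATCH_LIMIT] = F.scaled_dot_product_attention(
            q[i:i + SDP_BATCH_LIMIT], k[i:i + SDP_BATCH_LIMIT], v[i:i + SDP_BATCH_LIMIT]
        )
    return out

# Example with [batch, heads, tokens, head_dim] tensors.
q = k = v = torch.randn(4, 8, 16, 64)
print(chunked_sdpa(q, k, v).shape)
```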

View File

@@ -1,8 +1,11 @@
+import logging
+import math
+import torch
from . import supported_models, utils
from . import supported_models_base
-import math
-import logging
-import torch
def count_blocks(state_dict_keys, prefix_string):
count = 0
@@ -17,6 +20,7 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -32,10 +36,11 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
return None
def detect_unet_config(state_dict, key_prefix):
state_dict_keys = list(state_dict.keys())
-if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
+if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: # mmdit model
unet_config = {}
unet_config["in_channels"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[1]
patch_size = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[2]
@@ -65,7 +70,7 @@ def detect_unet_config(state_dict, key_prefix):
if rms_qk in state_dict_keys:
unet_config["qk_norm"] = "rms"
-unet_config["pos_embed_scaling_factor"] = None #unused for inference
+unet_config["pos_embed_scaling_factor"] = None # unused for inference
context_processor = '{}context_processor.layers.0.attn.qkv.weight'.format(key_prefix)
if context_processor in state_dict_keys:
unet_config["context_processor_layers"] = count_blocks(state_dict_keys, '{}context_processor.layers.'.format(key_prefix) + '{}.')
@@ -76,18 +81,18 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["x_block_self_attn_layers"].append(int(layer))
return unet_config
-if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: #stable cascade
+if '{}clf.1.weight'.format(key_prefix) in state_dict_keys: # stable cascade
unet_config = {}
text_mapper_name = '{}clip_txt_mapper.weight'.format(key_prefix)
if text_mapper_name in state_dict_keys:
unet_config['stable_cascade_stage'] = 'c'
w = state_dict[text_mapper_name]
-if w.shape[0] == 1536: #stage c lite
+if w.shape[0] == 1536: # stage c lite
unet_config['c_cond'] = 1536
unet_config['c_hidden'] = [1536, 1536]
unet_config['nhead'] = [24, 24]
unet_config['blocks'] = [[4, 12], [12, 4]]
-elif w.shape[0] == 2048: #stage c full
+elif w.shape[0] == 2048: # stage c full
unet_config['c_cond'] = 2048
elif '{}clip_mapper.weight'.format(key_prefix) in state_dict_keys:
unet_config['stable_cascade_stage'] = 'b'
@@ -97,19 +102,19 @@ def detect_unet_config(state_dict, key_prefix):
unet_config['nhead'] = [-1, -1, 20, 20]
unet_config['blocks'] = [[2, 6, 28, 6], [6, 28, 6, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [3, 3, 2, 2]]
-elif w.shape[-1] == 576: #stage b lite
+elif w.shape[-1] == 576: # stage b lite
unet_config['c_hidden'] = [320, 576, 1152, 1152]
unet_config['nhead'] = [-1, 9, 18, 18]
unet_config['blocks'] = [[2, 4, 14, 4], [4, 14, 4, 2]]
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
return unet_config
-if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
+if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: # stable audio dit
unet_config = {}
unet_config["audio_model"] = "dit1.0"
return unet_config
-if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit
+if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: # aura flow dit
unet_config = {}
unet_config["max_seq"] = state_dict['{}positional_encoding'.format(key_prefix)].shape[1]
unet_config["cond_seq_dim"] = state_dict['{}cond_seq_linear.weight'.format(key_prefix)].shape[1]
@@ -119,12 +124,12 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["n_layers"] = double_layers + single_layers
return unet_config
-if '{}mlp_t5.0.weight'.format(key_prefix) in state_dict_keys: #Hunyuan DiT
+if '{}mlp_t5.0.weight'.format(key_prefix) in state_dict_keys: # Hunyuan DiT
unet_config = {}
unet_config["image_model"] = "hydit"
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
-if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
+if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: # DiT-g/2
unet_config["mlp_ratio"] = 4.3637
if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
unet_config["size_cond"] = True
@@ -132,7 +137,7 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["image_model"] = "hydit1"
return unet_config
-if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
+if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: # Flux
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
@@ -149,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
-if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
+if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: # Genmo mochi preview
dit_config = {}
dit_config["image_model"] = "mochi_preview"
dit_config["depth"] = 48
@@ -176,7 +181,6 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["rope_theta"] = 10000.0
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -231,7 +235,7 @@ def detect_unet_config(state_dict, key_prefix):
block_keys_output = sorted(list(filter(lambda a: a.startswith(prefix_output), state_dict_keys)))
-if "{}0.op.weight".format(prefix) in block_keys: #new layer
+if "{}0.op.weight".format(prefix) in block_keys: # new layer
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
@@ -268,7 +272,6 @@ def detect_unet_config(state_dict, key_prefix):
else:
transformer_depth_output.append(0)
num_res_blocks.append(last_res_blocks)
channel_mult.append(last_channel_mult)
if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
@@ -304,6 +307,7 @@ def detect_unet_config(state_dict, key_prefix):
return unet_config
def model_config_from_unet_config(unet_config, state_dict=None):
for model_config in supported_models.models:
if model_config.matches(unet_config, state_dict):
@@ -312,6 +316,7 @@ def model_config_from_unet_config(unet_config, state_dict=None):
logging.error("no match {}".format(unet_config))
return None
def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
unet_config = detect_unet_config(state_dict, unet_key_prefix)
if unet_config is None:
@@ -328,9 +333,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
return model_config
def unet_prefix_from_state_dict(state_dict):
-candidates = ["model.diffusion_model.", #ldm/sgm models
-"model.model.", #audio models
+candidates = ["model.diffusion_model.", # ldm/sgm models
+"model.model.", # audio models
]
counts = {k: 0 for k in candidates}
for k in state_dict:
@@ -343,7 +349,7 @@ def unet_prefix_from_state_dict(state_dict):
if counts[top] > 5:
return top
else:
-return "model." #aura flow and others
+return "model." # aura flow and others
def convert_config(unet_config):
@@ -362,7 +368,7 @@ def convert_config(unet_config):
if isinstance(transformer_depth, int):
transformer_depth = len(channel_mult) * [transformer_depth]
if transformer_depth_middle is None:
transformer_depth_middle = transformer_depth[-1]
t_in = []
t_out = []
s = 1
@@ -470,10 +476,10 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'use_temporal_attention': False, 'use_temporal_resblock': False}
SDXL_diffusers_ip2p = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 8, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 2, 2, 10, 10], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': 10,
'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 0, 2, 2, 2, 10, 10, 10],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SSD_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
@@ -482,41 +488,40 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'use_temporal_attention': False, 'use_temporal_resblock': False}
Segmind_Vega = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [2, 2, 2], 'transformer_depth': [0, 0, 1, 1, 2, 2], 'transformer_depth_output': [0, 0, 0, 1, 1, 1, 2, 2, 2],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD09_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
'transformer_depth': [1, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True,
'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False, 'disable_self_attentions': [True, False, False]}
SD_XS = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
'adm_in_channels': None, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [1, 1, 1],
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
@@ -530,28 +535,30 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
return convert_config(unet_config)
return None
def model_config_from_diffusers_unet(state_dict):
unet_config = unet_config_from_diffusers_unet(state_dict)
if unet_config is not None:
return model_config_from_unet_config(unet_config)
return None
def convert_diffusers_mmdit(state_dict, output_prefix=""):
out_sd = {}
-if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
+if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: # AuraFlow
+num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
+num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+sd_map = utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
+elif 'x_embedder.weight' in state_dict: # Flux
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
hidden_size = state_dict["x_embedder.bias"].shape[0]
sd_map = utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
-elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: # SD3
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
sd_map = utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
-elif 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict: #AuraFlow
-num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
-num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
-sd_map = utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
else:
return None
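The AuraFlow probe now runs before the Flux probe, presumably because the new Flux marker (`x_embedder.weight`) is more generic than the old one. A toy illustration of this key-probing pattern; `detect_family` is a hypothetical helper, only the key names come from the diff:

```python
from typing import Dict, Optional

def detect_family(state_dict: Dict[str, object]) -> Optional[str]:
    """Probe marker keys; the most specific signatures have to be checked first."""
    if 'joint_transformer_blocks.0.attn.add_k_proj.weight' in state_dict:
        return "auraflow"
    if 'x_embedder.weight' in state_dict:  # generic marker, so it comes after the specific ones
        return "flux"
    if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict:
        return "sd3"
    return None

print(detect_family({'x_embedder.weight': None}))                                  # flux
print(detect_family({'joint_transformer_blocks.0.attn.add_k_proj.weight': None}))  # auraflow
```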

View File

@@ -1062,7 +1062,7 @@ def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
-if (14, 5) <= macos_version <= (15, 0, 1): # black image bug on recent versions of macOS
+if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
upcast = True
except:
pass

View File

@@ -917,7 +917,7 @@ class UNETLoader:
class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
-return {"required": { "clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),),
+return {"required": { "clip_name": (get_filename_list_with_downloadable("text_encoders", KNOWN_CLIP_MODELS),),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
}}
RETURN_TYPES = ("CLIP",)
@@ -938,15 +938,15 @@ class CLIPLoader:
else:
logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
-clip_path = get_or_download("clip", clip_name, KNOWN_CLIP_MODELS)
+clip_path = get_or_download("text_encoders", clip_name, KNOWN_CLIP_MODELS)
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
class DualCLIPLoader:
@classmethod
def INPUT_TYPES(s):
-return {"required": { "clip_name1": (get_filename_list_with_downloadable("clip"),), "clip_name2": (
-get_filename_list_with_downloadable("clip"),),
+return {"required": { "clip_name1": (get_filename_list_with_downloadable("text_encoders"),), "clip_name2": (
+get_filename_list_with_downloadable("text_encoders"),),
"type": (["sdxl", "sd3", "flux"], ),
}}
RETURN_TYPES = ("CLIP",)
@@ -955,8 +955,8 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, type):
-clip_path1 = get_or_download("clip", clip_name1)
-clip_path2 = get_or_download("clip", clip_name2)
+clip_path1 = get_or_download("text_encoders", clip_name1)
+clip_path2 = get_or_download("text_encoders", clip_name2)
if type == "sdxl":
clip_type = sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
@@ -1986,6 +1986,12 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageInvert": "Invert Image",
"ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images",
+"ImageCrop": "Image Crop",
+"ImageBlend": "Image Blend",
+"ImageBlur": "Image Blur",
+"ImageQuantize": "Image Quantize",
+"ImageSharpen": "Image Sharpen",
+"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
# _for_testing
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",

View File

@@ -183,6 +183,7 @@ class VAE:
self.downscale_ratio = 8
self.upscale_ratio = 8
self.latent_channels = 4
+self.latent_dim = 2
self.output_channels = 3
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
@@ -252,16 +253,22 @@ class VAE:
self.output_channels = 2
self.upscale_ratio = 2048
self.downscale_ratio = 2048
+self.latent_dim = 1
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd: #genmo mochi vae elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight": #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd: if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = utils.state_dict_prefix_replace(sd, {"": "decoder."}) sd = utils.state_dict_prefix_replace(sd, {"": "decoder."})
if "layers.4.layers.1.attn_block.attn.qkv.weight" in sd:
sd = utils.state_dict_prefix_replace(sd, {"": "encoder."})
self.first_stage_model = VideoVAE() self.first_stage_model = VideoVAE()
self.latent_channels = 12 self.latent_channels = 12
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8) self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
else: else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.") logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None self.first_stage_model = None
@@ -374,16 +381,21 @@ class VAE:
def encode(self, pixel_samples):
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
+if self.latent_dim == 3:
+pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
-samples = torch.empty((pixel_samples.shape[0], self.latent_channels) + tuple(map(lambda a: a // self.downscale_ratio, pixel_samples.shape[2:])), device=self.output_device)
+samples = None
for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
-samples[x:x + batch_number] = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
+out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
+if samples is None:
+samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
+samples[x:x + batch_number] = out
except model_management.OOM_EXCEPTION as e:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")

View File

@@ -12,7 +12,7 @@ class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
class MochiT5XXL(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
-super().__init__(device=device, dtype=dtype, clip_name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
+super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):

View File

@@ -388,10 +388,18 @@ MMDIT_MAP_BLOCK = {
("context_block.mlp.fc1.weight", "ff_context.net.0.proj.weight"),
("context_block.mlp.fc2.bias", "ff_context.net.2.bias"),
("context_block.mlp.fc2.weight", "ff_context.net.2.weight"),
+("context_block.attn.ln_q.weight", "attn.norm_added_q.weight"),
+("context_block.attn.ln_k.weight", "attn.norm_added_k.weight"),
("x_block.adaLN_modulation.1.bias", "norm1.linear.bias"),
("x_block.adaLN_modulation.1.weight", "norm1.linear.weight"),
("x_block.attn.proj.bias", "attn.to_out.0.bias"),
("x_block.attn.proj.weight", "attn.to_out.0.weight"),
+("x_block.attn.ln_q.weight", "attn.norm_q.weight"),
+("x_block.attn.ln_k.weight", "attn.norm_k.weight"),
+("x_block.attn2.proj.bias", "attn2.to_out.0.bias"),
+("x_block.attn2.proj.weight", "attn2.to_out.0.weight"),
+("x_block.attn2.ln_q.weight", "attn2.norm_q.weight"),
+("x_block.attn2.ln_k.weight", "attn2.norm_k.weight"),
("x_block.mlp.fc1.bias", "ff.net.0.proj.bias"),
("x_block.mlp.fc1.weight", "ff.net.0.proj.weight"),
("x_block.mlp.fc2.bias", "ff.net.2.bias"),
@@ -422,6 +430,12 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, offset, offset))
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
+k = "{}.attn2.".format(block_from)
+qkv = "{}.x_block.attn2.qkv.{}".format(block_to, end)
+key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, offset))
+key_map["{}to_k.{}".format(k, end)] = (qkv, (0, offset, offset))
+key_map["{}to_v.{}".format(k, end)] = (qkv, (0, offset * 2, offset))
for k in MMDIT_MAP_BLOCK:
key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

View File

@@ -3,9 +3,9 @@ import re
import torch
import comfy.model_management
-import comfy.sd
import comfy.model_patcher
import comfy.samplers
+import comfy.sd
from comfy.cmd import folder_paths
from comfy.model_downloader import get_or_download, get_filename_list_with_downloadable, KNOWN_CLIP_MODELS
from comfy.nodes import base_nodes as nodes
@@ -14,7 +14,7 @@ from comfy.nodes import base_nodes as nodes
class TripleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
-filename_list = get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS)
+filename_list = get_filename_list_with_downloadable("text_encoders", KNOWN_CLIP_MODELS)
return {"required": {"clip_name1": (filename_list,), "clip_name2": (filename_list,), "clip_name3": (filename_list,)
}}
@@ -24,9 +24,9 @@ class TripleCLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name1, clip_name2, clip_name3):
-clip_path1 = get_or_download("clip", clip_name1, KNOWN_CLIP_MODELS)
-clip_path2 = get_or_download("clip", clip_name2, KNOWN_CLIP_MODELS)
-clip_path3 = get_or_download("clip", clip_name3, KNOWN_CLIP_MODELS)
+clip_path1 = get_or_download("text_encoders", clip_name1, KNOWN_CLIP_MODELS)
+clip_path2 = get_or_download("text_encoders", clip_name2, KNOWN_CLIP_MODELS)
+clip_path3 = get_or_download("text_encoders", clip_name3, KNOWN_CLIP_MODELS)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)