Z Image model. (#10892)

Commit e9aae31fa2 (parent 0c18842acb)
Author: comfyanonymous
Date: 2025-11-25 15:41:45 -08:00 (committed via GitHub)
7 changed files with 196 additions and 149 deletions

comfy/ldm/lumina/model.py

@@ -11,6 +11,7 @@ import comfy.ldm.common_dit
 from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
 from comfy.ldm.modules.attention import optimized_attention_masked
 from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope
 import comfy.patcher_extension
@@ -31,6 +32,7 @@ class JointAttention(nn.Module):
         n_heads: int,
         n_kv_heads: Optional[int],
         qk_norm: bool,
+        out_bias: bool = False,
         operation_settings={},
     ):
         """
@@ -59,7 +61,7 @@ class JointAttention(nn.Module):
         self.out = operation_settings.get("operations").Linear(
             n_heads * self.head_dim,
             dim,
-            bias=False,
+            bias=out_bias,
             device=operation_settings.get("device"),
             dtype=operation_settings.get("dtype"),
         )
@@ -70,35 +72,6 @@
         else:
             self.q_norm = self.k_norm = nn.Identity()
 
-    @staticmethod
-    def apply_rotary_emb(
-        x_in: torch.Tensor,
-        freqs_cis: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Apply rotary embeddings to input tensors using the given frequency
-        tensor.
-
-        This function applies rotary embeddings to the given query 'xq' and
-        key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
-        input tensors are reshaped as complex numbers, and the frequency tensor
-        is reshaped for broadcasting compatibility. The resulting tensors
-        contain rotary embeddings and are returned as real tensors.
-
-        Args:
-            x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
-            freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
-                exponentials.
-
-        Returns:
-            Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
-                and key tensor with rotary embeddings.
-        """
-        t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
-        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
-        return t_out.reshape(*x_in.shape)
-
     def forward(
         self,
         x: torch.Tensor,
@@ -134,8 +107,7 @@
         xq = self.q_norm(xq)
         xk = self.k_norm(xk)
 
-        xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
-        xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
+        xq, xk = apply_rope(xq, xk, freqs_cis)
 
         n_rep = self.n_local_heads // self.n_local_kv_heads
         if n_rep >= 1:
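Note: the removed JointAttention.apply_rotary_emb helper and the shared comfy.ldm.flux.math.apply_rope compute the same rotation; apply_rope just handles query and key in one call. A minimal standalone sketch of that math for reference (illustrative only; the real helper lives in comfy.ldm.flux.math and may differ in dtype handling):

import torch

def apply_rope_sketch(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # Reshape the last dim into (pairs, 1, 2) and apply the same linear
    # combination the removed apply_rotary_emb used, to q and k at once.
    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape), xk_out.reshape(*xk.shape)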
@@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module):
         norm_eps: float,
         qk_norm: bool,
         modulation=True,
+        z_image_modulation=False,
+        attn_out_bias=False,
         operation_settings={},
     ) -> None:
         """
@@ -235,10 +209,10 @@
         super().__init__()
         self.dim = dim
         self.head_dim = dim // n_heads
-        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
+        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
         self.feed_forward = FeedForward(
             dim=dim,
-            hidden_dim=4 * dim,
+            hidden_dim=dim,
             multiple_of=multiple_of,
             ffn_dim_multiplier=ffn_dim_multiplier,
             operation_settings=operation_settings,
@@ -252,16 +226,27 @@
         self.modulation = modulation
         if modulation:
-            self.adaLN_modulation = nn.Sequential(
-                nn.SiLU(),
-                operation_settings.get("operations").Linear(
-                    min(dim, 1024),
-                    4 * dim,
-                    bias=True,
-                    device=operation_settings.get("device"),
-                    dtype=operation_settings.get("dtype"),
-                ),
-            )
+            if z_image_modulation:
+                self.adaLN_modulation = nn.Sequential(
+                    operation_settings.get("operations").Linear(
+                        min(dim, 256),
+                        4 * dim,
+                        bias=True,
+                        device=operation_settings.get("device"),
+                        dtype=operation_settings.get("dtype"),
+                    ),
+                )
+            else:
+                self.adaLN_modulation = nn.Sequential(
+                    nn.SiLU(),
+                    operation_settings.get("operations").Linear(
+                        min(dim, 1024),
+                        4 * dim,
+                        bias=True,
+                        device=operation_settings.get("device"),
+                        dtype=operation_settings.get("dtype"),
+                    ),
+                )
 
     def forward(
         self,
@@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
     The final layer of NextDiT.
     """
 
-    def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
+    def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
         super().__init__()
         self.norm_final = operation_settings.get("operations").LayerNorm(
             hidden_size,
@@ -340,10 +325,15 @@
             dtype=operation_settings.get("dtype"),
         )
 
+        if z_image_modulation:
+            min_mod = 256
+        else:
+            min_mod = 1024
+
         self.adaLN_modulation = nn.Sequential(
             nn.SiLU(),
             operation_settings.get("operations").Linear(
-                min(hidden_size, 1024),
+                min(hidden_size, min_mod),
                 hidden_size,
                 bias=True,
                 device=operation_settings.get("device"),
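Note: with z_image_modulation enabled, the adaLN projections in both JointTransformerBlock and FinalLayer take a 256-wide conditioning vector and drop the leading SiLU. A rough shape sketch, assuming the block still chunks the projection into four modulation terms as in the existing Lumina 2 forward pass:

import torch
import torch.nn as nn

dim = 3840                                # Z Image transformer width (see the model detection change below)
adaln_input = torch.randn(1, 256)         # 256-wide timestep embedding used when z_image_modulation is set
proj = nn.Linear(min(dim, 256), 4 * dim)  # mirrors the new adaLN branch: no SiLU in front
scale_msa, gate_msa, scale_mlp, gate_mlp = proj(adaln_input).chunk(4, dim=1)
print(scale_msa.shape)                    # torch.Size([1, 3840])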
@@ -373,12 +363,16 @@ class NextDiT(nn.Module):
         n_heads: int = 32,
         n_kv_heads: Optional[int] = None,
         multiple_of: int = 256,
-        ffn_dim_multiplier: Optional[float] = None,
+        ffn_dim_multiplier: float = 4.0,
         norm_eps: float = 1e-5,
         qk_norm: bool = False,
         cap_feat_dim: int = 5120,
         axes_dims: List[int] = (16, 56, 56),
         axes_lens: List[int] = (1, 512, 512),
+        rope_theta=10000.0,
+        z_image_modulation=False,
+        time_scale=1.0,
+        pad_tokens_multiple=None,
         image_model=None,
         device=None,
         dtype=None,
@@ -390,6 +384,8 @@
         self.in_channels = in_channels
         self.out_channels = in_channels
         self.patch_size = patch_size
+        self.time_scale = time_scale
+        self.pad_tokens_multiple = pad_tokens_multiple
 
         self.x_embedder = operation_settings.get("operations").Linear(
             in_features=patch_size * patch_size * in_channels,
@@ -411,6 +407,7 @@
                     norm_eps,
                     qk_norm,
                     modulation=True,
+                    z_image_modulation=z_image_modulation,
                     operation_settings=operation_settings,
                 )
                 for layer_id in range(n_refiner_layers)
@@ -434,7 +431,7 @@
             ]
         )
 
-        self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
+        self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
         self.cap_embedder = nn.Sequential(
             operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
             operation_settings.get("operations").Linear(
@@ -457,18 +454,24 @@
                     ffn_dim_multiplier,
                     norm_eps,
                     qk_norm,
+                    z_image_modulation=z_image_modulation,
+                    attn_out_bias=False,
                     operation_settings=operation_settings,
                 )
                 for layer_id in range(n_layers)
             ]
         )
         self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
-        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
+        self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
+
+        if self.pad_tokens_multiple is not None:
+            self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
+            self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
 
         assert (dim // n_heads) == sum(axes_dims)
         self.axes_dims = axes_dims
         self.axes_lens = axes_lens
-        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
+        self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
 
         self.dim = dim
         self.n_heads = n_heads
@@ -503,108 +506,42 @@
         bsz = len(x)
         pH = pW = self.patch_size
         device = x[0].device
-        dtype = x[0].dtype
-
-        if cap_mask is not None:
-            l_effective_cap_len = cap_mask.sum(dim=1).tolist()
-        else:
-            l_effective_cap_len = [num_tokens] * bsz
-
-        if cap_mask is not None and not torch.is_floating_point(cap_mask):
-            cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
-
-        img_sizes = [(img.size(1), img.size(2)) for img in x]
-        l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
-
-        max_seq_len = max(
-            (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
-        )
-        max_cap_len = max(l_effective_cap_len)
-        max_img_len = max(l_effective_img_len)
-
-        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
-
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            H, W = img_sizes[i]
-            H_tokens, W_tokens = H // pH, W // pW
-            assert H_tokens * W_tokens == img_len
-
-            rope_options = transformer_options.get("rope_options", None)
-            h_scale = 1.0
-            w_scale = 1.0
-            h_start = 0
-            w_start = 0
-            if rope_options is not None:
-                h_scale = rope_options.get("scale_y", 1.0)
-                w_scale = rope_options.get("scale_x", 1.0)
-                h_start = rope_options.get("shift_y", 0.0)
-                w_start = rope_options.get("shift_x", 0.0)
-
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
-            position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
-            row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
-            col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
-            position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
-            position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
-
-        freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
-
-        # build freqs_cis for cap and image individually
-        cap_freqs_cis_shape = list(freqs_cis.shape)
-        # cap_freqs_cis_shape[1] = max_cap_len
-        cap_freqs_cis_shape[1] = cap_feats.shape[1]
-        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        img_freqs_cis_shape = list(freqs_cis.shape)
-        img_freqs_cis_shape[1] = max_img_len
-        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
+
+        if self.pad_tokens_multiple is not None:
+            pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
+            cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
+
+        cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
+        cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
+
+        B, C, H, W = x.shape
+        x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+        H_tokens, W_tokens = H // pH, W // pW
+
+        x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
+        x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
+        x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+        x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+
+        if self.pad_tokens_multiple is not None:
+            pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
+            x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
+            x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
+
+        freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
 
         # refine context
         for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
+            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
 
-        # refine image
-        flat_x = []
-        for i in range(bsz):
-            img = x[i]
-            C, H, W = img.size()
-            img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
-            flat_x.append(img)
-        x = flat_x
-        padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
-        padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
-        for i in range(bsz):
-            padded_img_embed[i, :l_effective_img_len[i]] = x[i]
-            padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
-
-        padded_img_embed = self.x_embedder(padded_img_embed)
-        padded_img_mask = padded_img_mask.unsqueeze(1)
+        padded_img_mask = None
         for layer in self.noise_refiner:
-            padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
+            x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
 
-        if cap_mask is not None:
-            mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
-            mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
-        else:
-            mask = None
-
-        padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
-        for i in range(bsz):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-
-            padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
-            padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
-
+        padded_full_embed = torch.cat((cap_feats, x), dim=1)
+        mask = None
+        img_sizes = [(H, W)] * bsz
+        l_effective_cap_len = [cap_feats.shape[1]] * bsz
         return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
 
     def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
@@ -627,7 +564,7 @@
        y: (N,) tensor of text tokens/features
        """
 
-        t = self.t_embedder(t, dtype=x.dtype) # (N, D)
+        t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
         adaln_input = t
 
         cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
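Note: the rewritten embedding path assumes a single image size per batch. The latent is patchified with a plain view/permute, and both caption and image tokens are padded up to pad_tokens_multiple with the new learned pad tokens before RoPE positions are built. A small sketch of the token and padding arithmetic with illustrative sizes (patch_size 2, in_channels 16 and pad_tokens_multiple 32 come from the detection code below; the 128x128 latent is made up):

import torch

patch_size = 2
pad_tokens_multiple = 32

latent = torch.randn(1, 16, 128, 128)  # (B, C, H, W) latent, illustrative size
B, C, H, W = latent.shape
H_tokens, W_tokens = H // patch_size, W // patch_size
tokens = latent.view(B, C, H_tokens, patch_size, W_tokens, patch_size)
tokens = tokens.permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
print(tokens.shape)  # torch.Size([1, 4096, 64]); 64 = 2 * 2 * 16 values per patch

# round the token count up to the next multiple of pad_tokens_multiple
pad_extra = (-tokens.shape[1]) % pad_tokens_multiple
print(pad_extra)     # 0 here; a 77-token caption would get (-77) % 32 = 19 pad tokens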

comfy/ldm/modules/diffusionmodules/mmdit.py

@@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
     Embeds scalar timesteps into vector representations.
     """
 
-    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
         super().__init__()
+        if output_size is None:
+            output_size = hidden_size
         self.mlp = nn.Sequential(
             operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
             nn.SiLU(),
-            operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
+            operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
         )
         self.frequency_embedding_size = frequency_embedding_size
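Note: output_size only changes the width of the second MLP layer; left at None the embedder behaves exactly as before. For z_image the 256-wide output is what feeds the min(dim, 256) adaLN projections above. A quick shape sketch (plain nn.Linear stands in for the comfy ops wrappers):

import torch
import torch.nn as nn

frequency_embedding_size, hidden_size, output_size = 256, 1024, 256
mlp = nn.Sequential(
    nn.Linear(frequency_embedding_size, hidden_size),
    nn.SiLU(),
    nn.Linear(hidden_size, output_size),  # was Linear(hidden_size, hidden_size) before this change
)
t_freq = torch.randn(4, frequency_embedding_size)  # stand-in for the sinusoidal timestep features
print(mlp(t_freq).shape)                           # torch.Size([4, 256])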

comfy/model_base.py

@@ -1114,9 +1114,13 @@ class Lumina2(BaseModel):
             if torch.numel(attention_mask) != attention_mask.sum():
                 out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
                 out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
+
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+            if 'num_tokens' not in out:
+                out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
         return out
 
 
 class WAN21(BaseModel):

comfy/model_detection.py

@@ -416,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["image_model"] = "lumina2"
         dit_config["patch_size"] = 2
         dit_config["in_channels"] = 16
-        dit_config["dim"] = 2304
-        dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
+        w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
+        dit_config["dim"] = w.shape[0]
+        dit_config["cap_feat_dim"] = w.shape[1]
         dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
-        dit_config["n_heads"] = 24
-        dit_config["n_kv_heads"] = 8
         dit_config["qk_norm"] = True
-        dit_config["axes_dims"] = [32, 32, 32]
-        dit_config["axes_lens"] = [300, 512, 512]
+
+        if dit_config["dim"] == 2304: # Original Lumina 2
+            dit_config["n_heads"] = 24
+            dit_config["n_kv_heads"] = 8
+            dit_config["axes_dims"] = [32, 32, 32]
+            dit_config["axes_lens"] = [300, 512, 512]
+            dit_config["rope_theta"] = 10000.0
+            dit_config["ffn_dim_multiplier"] = 4.0
+        elif dit_config["dim"] == 3840: # Z image
+            dit_config["n_heads"] = 30
+            dit_config["n_kv_heads"] = 30
+            dit_config["axes_dims"] = [32, 48, 48]
+            dit_config["axes_lens"] = [1536, 512, 512]
+            dit_config["rope_theta"] = 256.0
+            dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
+            dit_config["z_image_modulation"] = True
+            dit_config["time_scale"] = 1000.0
+            if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
+                dit_config["pad_tokens_multiple"] = 32
+
         return dit_config
 
     if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
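Note: instead of hard-coding dim = 2304, the detector now reads the transformer width and caption feature dim off cap_embedder.1.weight and branches on the width to tell original Lumina 2 (2304) apart from Z Image (3840). A trimmed standalone sketch of that dispatch, using made-up tensors in place of a real checkpoint:

import torch

def sketch_detect(state_dict, prefix=""):
    # cap_embedder.1 is the Linear after the RMSNorm, so its weight is (dim, cap_feat_dim)
    w = state_dict[f"{prefix}cap_embedder.1.weight"]
    cfg = {"dim": w.shape[0], "cap_feat_dim": w.shape[1]}
    if cfg["dim"] == 2304:    # original Lumina 2
        cfg.update(n_heads=24, n_kv_heads=8, rope_theta=10000.0)
    elif cfg["dim"] == 3840:  # Z Image
        cfg.update(n_heads=30, n_kv_heads=30, rope_theta=256.0, z_image_modulation=True)
        if f"{prefix}cap_pad_token" in state_dict:
            cfg["pad_tokens_multiple"] = 32
    return cfg

fake_sd = {"cap_embedder.1.weight": torch.zeros(3840, 2560), "cap_pad_token": torch.zeros(1, 3840)}
print(sketch_detect(fake_sd))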

comfy/sd.py

@@ -52,6 +52,7 @@ import comfy.text_encoders.ace
 import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
 import comfy.text_encoders.hunyuan_image
+import comfy.text_encoders.z_image
 
 import comfy.model_patcher
 import comfy.lora
@@ -953,6 +954,8 @@ class TEModel(Enum):
     GEMMA_3_4B = 13
     MISTRAL3_24B = 14
     MISTRAL3_24B_PRUNED_FLUX2 = 15
+    QWEN3_4B = 16
 
 def detect_te_model(sd):
     if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -985,6 +988,8 @@
         if weight.shape[0] == 512:
             return TEModel.QWEN25_7B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
+        if 'model.layers.0.self_attn.q_norm.weight' in sd:
+            return TEModel.QWEN3_4B
         weight = sd['model.layers.0.post_attention_layernorm.weight']
         if weight.shape[0] == 5120:
             if "model.layers.39.post_attention_layernorm.weight" in sd:
@@ -1110,6 +1115,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
             clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
             tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
+        elif te_model == TEModel.QWEN3_4B:
+            clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
         else:
             # clip_l
             if clip_type == CLIPType.SD3:

comfy/text_encoders/llama.py

@@ -78,6 +78,28 @@ class Qwen25_3BConfig:
     rope_scale = None
     final_norm: bool = True
 
+@dataclass
+class Qwen3_4BConfig:
+    vocab_size: int = 151936
+    hidden_size: int = 2560
+    intermediate_size: int = 9728
+    num_hidden_layers: int = 36
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 40960
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    rope_scale = None
+    final_norm: bool = True
+
 @dataclass
 class Qwen25_7BVLI_Config:
     vocab_size: int = 152064
@@ -511,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype
 
+
+class Qwen3_4B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Qwen3_4BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
 class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
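Note: Qwen3_4BConfig describes a grouped-query attention layout: 32 query heads but only 8 key/value heads at head_dim 128, so the q/k/v projections are not square against the 2560-wide hidden state. The arithmetic implied by the config (the projection layout itself is an assumption based on the Llama-style attention these configs feed into):

hidden_size = 2560
num_attention_heads = 32
num_key_value_heads = 8
head_dim = 128

q_out = num_attention_heads * head_dim               # 4096, wider than the hidden state
kv_out = num_key_value_heads * head_dim              # 1024
n_rep = num_attention_heads // num_key_value_heads   # each k/v head is shared by 4 query heads
print(q_out, kv_out, n_rep)                          # 4096 1024 4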

comfy/text_encoders/z_image.py (new file)

@@ -0,0 +1,48 @@
+from transformers import Qwen2Tokenizer
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import os
+
+
+class Qwen3Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class ZImageTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer)
+        self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+        if llama_template is None:
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+        return tokens
+
+
+class Qwen3_4BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class ZImageTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
+
+
+def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+    class ZImageTEModel_(ZImageTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            if llama_quantization_metadata is not None:
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return ZImageTEModel_
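Note: ZImageTokenizer wraps the prompt in a Qwen-style chat template before tokenizing, and Qwen3_4BModel then encodes the whole formatted string from the second-to-last hidden layer (layer_idx=-2). A minimal sketch of what the formatted prompt looks like (plain string formatting, no ComfyUI imports needed):

llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
prompt = "a photo of a cat wearing a tiny wizard hat"
print(llama_template.format(prompt))
# <|im_start|>user
# a photo of a cat wearing a tiny wizard hat<|im_end|>
# <|im_start|>assistant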