From e9aae31fa241a6a63a368800146ea91629d4e8c2 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 25 Nov 2025 15:41:45 -0800 Subject: [PATCH] Z Image model. (#10892) --- comfy/ldm/lumina/model.py | 219 +++++++------------- comfy/ldm/modules/diffusionmodules/mmdit.py | 6 +- comfy/model_base.py | 4 + comfy/model_detection.py | 29 ++- comfy/sd.py | 8 + comfy/text_encoders/llama.py | 31 +++ comfy/text_encoders/z_image.py | 48 +++++ 7 files changed, 196 insertions(+), 149 deletions(-) create mode 100644 comfy/text_encoders/z_image.py diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index b4494a51d..c8643eb82 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -11,6 +11,7 @@ import comfy.ldm.common_dit from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension @@ -31,6 +32,7 @@ class JointAttention(nn.Module): n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, + out_bias: bool = False, operation_settings={}, ): """ @@ -59,7 +61,7 @@ class JointAttention(nn.Module): self.out = operation_settings.get("operations").Linear( n_heads * self.head_dim, dim, - bias=False, + bias=out_bias, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"), ) @@ -70,35 +72,6 @@ class JointAttention(nn.Module): else: self.q_norm = self.k_norm = nn.Identity() - @staticmethod - def apply_rotary_emb( - x_in: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> torch.Tensor: - """ - Apply rotary embeddings to input tensors using the given frequency - tensor. - - This function applies rotary embeddings to the given query 'xq' and - key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The - input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors - contain rotary embeddings and are returned as real tensors. - - Args: - x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex - exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor - and key tensor with rotary embeddings. - """ - - t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2) - t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x_in.shape) - def forward( self, x: torch.Tensor, @@ -134,8 +107,7 @@ class JointAttention(nn.Module): xq = self.q_norm(xq) xk = self.k_norm(xk) - xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) - xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) + xq, xk = apply_rope(xq, xk, freqs_cis) n_rep = self.n_local_heads // self.n_local_kv_heads if n_rep >= 1: @@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module): norm_eps: float, qk_norm: bool, modulation=True, + z_image_modulation=False, + attn_out_bias=False, operation_settings={}, ) -> None: """ @@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module): super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings) + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings) self.feed_forward = FeedForward( dim=dim, - hidden_dim=4 * dim, + hidden_dim=dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, operation_settings=operation_settings, @@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module): self.modulation = modulation if modulation: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - operation_settings.get("operations").Linear( - min(dim, 1024), - 4 * dim, - bias=True, - device=operation_settings.get("device"), - dtype=operation_settings.get("dtype"), - ), - ) + if z_image_modulation: + self.adaLN_modulation = nn.Sequential( + operation_settings.get("operations").Linear( + min(dim, 256), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) def forward( self, @@ -323,7 +308,7 @@ class FinalLayer(nn.Module): The final layer of NextDiT. """ - def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}): + def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}): super().__init__() self.norm_final = operation_settings.get("operations").LayerNorm( hidden_size, @@ -340,10 +325,15 @@ class FinalLayer(nn.Module): dtype=operation_settings.get("dtype"), ) + if z_image_modulation: + min_mod = 256 + else: + min_mod = 1024 + self.adaLN_modulation = nn.Sequential( nn.SiLU(), operation_settings.get("operations").Linear( - min(hidden_size, 1024), + min(hidden_size, min_mod), hidden_size, bias=True, device=operation_settings.get("device"), @@ -373,12 +363,16 @@ class NextDiT(nn.Module): n_heads: int = 32, n_kv_heads: Optional[int] = None, multiple_of: int = 256, - ffn_dim_multiplier: Optional[float] = None, + ffn_dim_multiplier: float = 4.0, norm_eps: float = 1e-5, qk_norm: bool = False, cap_feat_dim: int = 5120, axes_dims: List[int] = (16, 56, 56), axes_lens: List[int] = (1, 512, 512), + rope_theta=10000.0, + z_image_modulation=False, + time_scale=1.0, + pad_tokens_multiple=None, image_model=None, device=None, dtype=None, @@ -390,6 +384,8 @@ class NextDiT(nn.Module): self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size + self.time_scale = time_scale + self.pad_tokens_multiple = pad_tokens_multiple self.x_embedder = operation_settings.get("operations").Linear( in_features=patch_size * patch_size * in_channels, @@ -411,6 +407,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=True, + z_image_modulation=z_image_modulation, operation_settings=operation_settings, ) for layer_id in range(n_refiner_layers) @@ -434,7 +431,7 @@ class NextDiT(nn.Module): ] ) - self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) + self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings) self.cap_embedder = nn.Sequential( operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear( @@ -457,18 +454,24 @@ class NextDiT(nn.Module): ffn_dim_multiplier, norm_eps, qk_norm, + z_image_modulation=z_image_modulation, + attn_out_bias=False, operation_settings=operation_settings, ) for layer_id in range(n_layers) ] ) self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) + + if self.pad_tokens_multiple is not None: + self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) assert (dim // n_heads) == sum(axes_dims) self.axes_dims = axes_dims self.axes_lens = axes_lens - self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims) self.dim = dim self.n_heads = n_heads @@ -503,108 +506,42 @@ class NextDiT(nn.Module): bsz = len(x) pH = pW = self.patch_size device = x[0].device - dtype = x[0].dtype - if cap_mask is not None: - l_effective_cap_len = cap_mask.sum(dim=1).tolist() - else: - l_effective_cap_len = [num_tokens] * bsz + if self.pad_tokens_multiple is not None: + pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) - if cap_mask is not None and not torch.is_floating_point(cap_mask): - cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max + cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 - img_sizes = [(img.size(1), img.size(2)) for img in x] - l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + B, C, H, W = x.shape + x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - max_seq_len = max( - (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) - ) - max_cap_len = max(l_effective_cap_len) - max_img_len = max(l_effective_img_len) + H_tokens, W_tokens = H // pH, W // pW + x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 + x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device) + if self.pad_tokens_multiple is not None: + pad_extra = (-x.shape[1]) % self.pad_tokens_multiple + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // pH, W // pW - assert H_tokens * W_tokens == img_len - - rope_options = transformer_options.get("rope_options", None) - h_scale = 1.0 - w_scale = 1.0 - h_start = 0 - w_start = 0 - if rope_options is not None: - h_scale = rope_options.get("scale_y", 1.0) - w_scale = rope_options.get("scale_x", 1.0) - - h_start = rope_options.get("shift_y", 0.0) - w_start = rope_options.get("shift_x", 0.0) - - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device) - position_ids[i, cap_len:cap_len+img_len, 0] = cap_len - row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() - col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids[i, cap_len:cap_len+img_len, 1] = row_ids - position_ids[i, cap_len:cap_len+img_len, 2] = col_ids - - freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype) - - # build freqs_cis for cap and image individually - cap_freqs_cis_shape = list(freqs_cis.shape) - # cap_freqs_cis_shape[1] = max_cap_len - cap_freqs_cis_shape[1] = cap_feats.shape[1] - cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] + freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) # refine context for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) + cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) - # refine image - flat_x = [] - for i in range(bsz): - img = x[i] - C, H, W = img.size() - img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_x.append(img) - x = flat_x - padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device) - for i in range(bsz): - padded_img_embed[i, :l_effective_img_len[i]] = x[i] - padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max - - padded_img_embed = self.x_embedder(padded_img_embed) - padded_img_mask = padded_img_mask.unsqueeze(1) + padded_img_mask = None for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) - - if cap_mask is not None: - mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) - mask[:, :max_cap_len] = cap_mask[:, :max_cap_len] - else: - mask = None - - padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - - padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] - padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] + x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + padded_full_embed = torch.cat((cap_feats, x), dim=1) + mask = None + img_sizes = [(H, W)] * bsz + l_effective_cap_len = [cap_feats.shape[1]] * bsz return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): @@ -627,7 +564,7 @@ class NextDiT(nn.Module): y: (N,) tensor of text tokens/features """ - t = self.t_embedder(t, dtype=x.dtype) # (N, D) + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 42f406f1a..0dc8fe789 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): super().__init__() + if output_size is None: + output_size = hidden_size self.mlp = nn.Sequential( operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), nn.SiLU(), - operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), ) self.frequency_embedding_size = frequency_embedding_size diff --git a/comfy/model_base.py b/comfy/model_base.py index cad79ecbd..cc21b1de9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1114,9 +1114,13 @@ class Lumina2(BaseModel): if torch.numel(attention_mask) != attention_mask.sum(): out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item())) + cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + if 'num_tokens' not in out: + out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) + return out class WAN21(BaseModel): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b2ba1459d..7afe4a798 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -416,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 dit_config["in_channels"] = 16 - dit_config["dim"] = 2304 - dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1] + w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)] + dit_config["dim"] = w.shape[0] + dit_config["cap_feat_dim"] = w.shape[1] dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') - dit_config["n_heads"] = 24 - dit_config["n_kv_heads"] = 8 dit_config["qk_norm"] = True - dit_config["axes_dims"] = [32, 32, 32] - dit_config["axes_lens"] = [300, 512, 512] + + if dit_config["dim"] == 2304: # Original Lumina 2 + dit_config["n_heads"] = 24 + dit_config["n_kv_heads"] = 8 + dit_config["axes_dims"] = [32, 32, 32] + dit_config["axes_lens"] = [300, 512, 512] + dit_config["rope_theta"] = 10000.0 + dit_config["ffn_dim_multiplier"] = 4.0 + elif dit_config["dim"] == 3840: # Z image + dit_config["n_heads"] = 30 + dit_config["n_kv_heads"] = 30 + dit_config["axes_dims"] = [32, 48, 48] + dit_config["axes_lens"] = [1536, 512, 512] + dit_config["rope_theta"] = 256.0 + dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) + dit_config["z_image_modulation"] = True + dit_config["time_scale"] = 1000.0 + if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: + dit_config["pad_tokens_multiple"] = 32 + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 diff --git a/comfy/sd.py b/comfy/sd.py index 14dd8944c..350fae92b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -52,6 +52,7 @@ import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.z_image import comfy.model_patcher import comfy.lora @@ -953,6 +954,8 @@ class TEModel(Enum): GEMMA_3_4B = 13 MISTRAL3_24B = 14 MISTRAL3_24B_PRUNED_FLUX2 = 15 + QWEN3_4B = 16 + def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -985,6 +988,8 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if "model.layers.0.post_attention_layernorm.weight" in sd: + if 'model.layers.0.self_attn.q_norm.weight' in sd: + return TEModel.QWEN3_4B weight = sd['model.layers.0.post_attention_layernorm.weight'] if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: @@ -1110,6 +1115,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) + elif te_model == TEModel.QWEN3_4B: + clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index d47ed27bc..cd4b5f76c 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -78,6 +78,28 @@ class Qwen25_3BConfig: rope_scale = None final_norm: bool = True +@dataclass +class Qwen3_4BConfig: + vocab_size: int = 151936 + hidden_size: int = 2560 + intermediate_size: int = 9728 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + @dataclass class Qwen25_7BVLI_Config: vocab_size: int = 152064 @@ -511,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_4B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_4BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_7BVLI(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py new file mode 100644 index 000000000..bb9273b20 --- /dev/null +++ b/comfy/text_encoders/z_image.py @@ -0,0 +1,48 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class ZImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + + +class Qwen3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ZImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) + + +def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): + class ZImageTEModel_(ZImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return ZImageTEModel_