This commit is contained in:
kijai 2026-04-12 16:53:12 +03:00
parent 6b803abe5a
commit 17ccff25be
2 changed files with 41 additions and 52 deletions

View File

@ -1400,17 +1400,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model == TEModel.GEMMA_4_E4B:
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_E4B)
clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4Tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_4_E2B:
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_E2B)
clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4_E2BTokenizerWrapper
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_4_31B:
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_31B)
clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4_31BTokenizerWrapper
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))

View File

@ -70,6 +70,7 @@ class Gemma4_31B_Config(Gemma4Config):
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
hidden_size_per_layer_input: int = 0
num_kv_shared_layers: int = 0
audio_config = None
vision_config = GEMMA4_VISION_31B_CONFIG
@ -82,10 +83,6 @@ def _apply_rotary_pos_emb(x, freqs_cis):
out[..., half:] += x[..., :half] * sin[..., half:]
return out
def _apply_rope_gemma(xq, xk, freqs_cis):
    """Apply rotary position embeddings to the query and key tensors.

    Returns a ``(xq_rotated, xk_rotated)`` pair; both tensors go through
    the same `_apply_rotary_pos_emb` transform with the shared `freqs_cis`.
    """
    rotated_q = _apply_rotary_pos_emb(xq, freqs_cis)
    rotated_k = _apply_rotary_pos_emb(xk, freqs_cis)
    return rotated_q, rotated_k
class Gemma4Attention(nn.Module):
def __init__(self, config, head_dim, device=None, dtype=None, ops=None):
super().__init__()
@ -138,7 +135,8 @@ class Gemma4Attention(nn.Module):
xv = rms_norm(xv, fused=False)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
xq, xk = _apply_rope_gemma(xq, xk, freqs_cis=freqs_cis)
xq = _apply_rotary_pos_emb(xq, freqs_cis)
xk = _apply_rotary_pos_emb(xk, freqs_cis)
present_key_value = None
if past_key_value is not None:
@ -445,13 +443,6 @@ class Gemma4AudioMixin:
return None, None
# Gemma4 E4B variant: language model built from the base Gemma4Config,
# plus the audio encoder provided by Gemma4AudioMixin.
class Gemma4_E4B(Gemma4AudioMixin, Gemma4Base):
    def __init__(self, config_dict, dtype, device, operations):
        """Build the language model from `config_dict`, then attach the audio encoder."""
        super().__init__()
        self._init_model(Gemma4Config(**config_dict), dtype, device, operations)
        # Audio encoder reads its sizes from the resolved model config,
        # so this must run after _init_model.
        self._init_audio(self.model.config, dtype, device, operations)
# Vision Encoder
def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
@ -559,8 +550,8 @@ class Gemma4VisionAttention(nn.Module):
self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops)
self.q_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.k_norm = RMSNorm(self.head_dim, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
def forward(self, x, freqs, attention_mask=None):
batch_size, seq_length, _ = x.shape
@ -587,7 +578,7 @@ class Gemma4VisionLayer(nn.Module):
super().__init__()
self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops)
self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops)
norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
hidden = config["hidden_size"]
self.input_layernorm = RMSNorm(hidden, **norm_kwargs)
self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs)
@ -741,7 +732,7 @@ class Gemma4AudioConvSubsampler(nn.Module):
"""2D convolution subsampling for audio features"""
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
eps = config.get("rms_norm_eps", 1e-6)
eps = config["rms_norm_eps"]
self.layer0 = nn.ModuleDict({
'conv': ops.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
'norm': ops.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
@ -777,10 +768,10 @@ class Gemma4AudioFeedForward(nn.Module):
super().__init__()
hidden_size = config["hidden_size"]
intermediate_size = config.get("intermediate_size", hidden_size * 4)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops)
self.post_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.post_layer_scale = config.get("residual_weight", 0.5)
def forward(self, x):
@ -921,12 +912,12 @@ class Gemma4AudioLConv1d(nn.Module):
super().__init__()
hidden_size = config["hidden_size"]
conv_kernel_size = config.get("conv_kernel_size", 5)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops)
# Causal conv: left-pad only
self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype)
self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1
self.conv_norm = RMSNorm(hidden_size, eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops)
def forward(self, x):
@ -949,7 +940,7 @@ class Gemma4AudioLayer(nn.Module):
super().__init__()
self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops)
norm_kwargs = dict(eps=config.get("rms_norm_eps", 1e-6), device=device, dtype=dtype, fused=False)
norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype, fused=False)
hidden_size = config["hidden_size"]
self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs)
self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs)
@ -1239,10 +1230,6 @@ class Gemma4Tokenizer(sd1_clip.SD1Tokenizer):
class Gemma4Model(sd1_clip.SDClipModel):
model_class = None
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
self.dtypes = set()
self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=self.model_class, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@ -1274,7 +1261,7 @@ class Gemma4Model(sd1_clip.SDClipModel):
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids)
def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=Gemma4_E4B):
def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None):
clip_model = type('Gemma4Model_', (Gemma4Model,), {'model_class': model_class})
class Gemma4TEModel_(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
@ -1287,20 +1274,27 @@ def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=Ge
return Gemma4TEModel_
# Variants: config + model_class + embedding_size
# Gemma4 E2B variant: same wiring as E4B but built from Gemma4_E2B_Config.
class Gemma4_E2B(Gemma4AudioMixin, Gemma4Base):
    def __init__(self, config_dict, dtype, device, operations):
        """Build the E2B language model from `config_dict`, then attach the audio encoder."""
        super().__init__()
        self._init_model(Gemma4_E2B_Config(**config_dict), dtype, device, operations)
        # Must follow _init_model: the audio encoder reads the resolved model config.
        self._init_audio(self.model.config, dtype, device, operations)
# Variants
# Gemma4 31B variant: language model only — no Gemma4AudioMixin base and
# no _init_audio call, so this variant carries no audio encoder.
class Gemma4_31B(Gemma4Base):
    def __init__(self, config_dict, dtype, device, operations):
        """Build the 31B language model from `config_dict`."""
        super().__init__()
        self._init_model(Gemma4_31B_Config(**config_dict), dtype, device, operations)
def _make_variant(config_cls):
    """Build a concrete Gemma4 model class for *config_cls*.

    The returned class wires up the language model and, when the config
    declares an ``audio_config``, the audio encoder as well.  It also
    carries a ``tokenizer`` class attribute whose SD tokenizer advertises
    the variant's embedding size.

    Fix over the previous version: generated classes were all named
    ``Variant`` (and tokenizer helpers ``T``), which made reprs, tracebacks
    and error messages indistinguishable between variants.  Derive a
    readable name from the config class instead.
    """
    audio = config_cls.audio_config is not None
    bases = (Gemma4AudioMixin, Gemma4Base) if audio else (Gemma4Base,)

    # "Gemma4_E2B_Config" -> "Gemma4_E2B"; "Gemma4Config" -> "Gemma4".
    name = config_cls.__name__
    if name.endswith("Config"):
        name = name[: -len("Config")].rstrip("_")

    class Variant(*bases):
        def __init__(self, config_dict, dtype, device, operations):
            super().__init__()
            self._init_model(config_cls(**config_dict), dtype, device, operations)
            if audio:
                # Audio encoder sizes come from the resolved model config,
                # so this must run after _init_model.
                self._init_audio(self.model.config, dtype, device, operations)

    Variant.__name__ = name
    Variant.__qualname__ = name

    embedding_size = config_cls.hidden_size
    if embedding_size == Gemma4SDTokenizer.embedding_size:
        Variant.tokenizer = Gemma4Tokenizer
    else:
        # Non-default hidden size: expose a tokenizer whose SD wrapper
        # reports the variant's embedding size.
        tok_cls = type(name + "SDTokenizer", (Gemma4SDTokenizer,),
                       {'embedding_size': embedding_size})
        Variant.tokenizer = type(name + "Tokenizer", (Gemma4Tokenizer,),
                                 {'tokenizer_class': tok_cls})

    return Variant
# Tokenizer for the E2B variant: the wrapped SD tokenizer advertises a
# 1536-dim embedding instead of Gemma4SDTokenizer's default.
class Gemma4_E2BTokenizerWrapper(Gemma4Tokenizer):
    tokenizer_class = type('T', (Gemma4SDTokenizer,), {'embedding_size': 1536})
# Tokenizer for the 31B variant: the wrapped SD tokenizer advertises a
# 5376-dim embedding instead of Gemma4SDTokenizer's default.
class Gemma4_31BTokenizerWrapper(Gemma4Tokenizer):
    tokenizer_class = type('T', (Gemma4SDTokenizer,), {'embedding_size': 5376})
# Concrete variant classes, each generated from its config by _make_variant.
Gemma4_E4B = _make_variant(Gemma4Config)
Gemma4_E2B = _make_variant(Gemma4_E2B_Config)
Gemma4_31B = _make_variant(Gemma4_31B_Config)