From 22f6e407328ae4645fb56b0bf2c72d1c5aefaf48 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Fri, 5 Jun 2026 20:53:38 +0300 Subject: [PATCH] Support Gemma4 12B --- comfy/sd.py | 9 +- comfy/text_encoders/gemma4.py | 246 +++++++++++++++++++++++++++++----- comfy/text_encoders/llama.py | 4 +- 3 files changed, 218 insertions(+), 41 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index a66ba1bfb..f9aa19490 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1353,6 +1353,7 @@ class TEModel(Enum): GEMMA_4_31B = 31 T5_GEMMA = 32 GPT_OSS_20B = 33 + GEMMA_4_12B = 34 def detect_te_model(sd): @@ -1382,6 +1383,9 @@ def detect_te_model(sd): if 'model.layers.0.post_feedforward_layernorm.weight' in sd: if 'model.layers.59.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_4_31B + # Gemma4 12B Unified: 48 layers, encoder-free; global layers drop v_proj (attention_k_eq_v). + if 'model.layers.47.self_attn.q_norm.weight' in sd and 'model.layers.5.self_attn.v_proj.weight' not in sd: + return TEModel.GEMMA_4_12B if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd: return TEModel.GEMMA_4_E4B if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd: @@ -1535,10 +1539,11 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) - elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B): + elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B, TEModel.GEMMA_4_12B): variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B, TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B, - TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model] + TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B, + TEModel.GEMMA_4_12B: comfy.text_encoders.gemma4.Gemma4_12B}[te_model] clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant) clip_target.tokenizer = variant.tokenizer tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py index f050061ed..21a69f805 100644 --- a/comfy/text_encoders/gemma4.py +++ b/comfy/text_encoders/gemma4.py @@ -1,6 +1,9 @@ import torch import torch.nn as nn +import torchaudio.functional as AF +import torchvision.transforms.functional as TVF import numpy as np +from tokenizers import Tokenizer from dataclasses import dataclass import math @@ -21,6 +24,10 @@ GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_siz GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3} GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5} +# Encoder-free (gemma4_unified) multimodal embedders: raw patches/waveform projected directly into LM space. +GEMMA4_UNIFIED_VISION_CONFIG = {"model_patch_size": 48, "patch_size": 16, "pooling_kernel_size": 3, "mm_embed_dim": 3840, "mm_posemb_size": 1120, "output_proj_dims": 3840, "rms_norm_eps": 1e-6} +GEMMA4_UNIFIED_AUDIO_CONFIG = {"audio_samples_per_token": 640, "output_proj_dims": 640, "rms_norm_eps": 1e-6} + @dataclass class Gemma4Config: vocab_size: int = 262144 @@ -35,6 +42,9 @@ class Gemma4Config: transformer_type: str = "gemma4" head_dim = 256 global_head_dim = 512 + num_global_key_value_heads = None + attention_k_eq_v = False + vision_bidirectional = False rms_norm_add = False mlp_activation = "gelu_pytorch_tanh" qkv_bias = False @@ -72,12 +82,29 @@ class Gemma4_31B_Config(Gemma4Config): num_hidden_layers: int = 60 num_attention_heads: int = 32 num_key_value_heads: int = 16 + vision_bidirectional = True sliding_attention = [1024, 1024, 1024, 1024, 1024, False] hidden_size_per_layer_input: int = 0 num_kv_shared_layers: int = 0 audio_config = None vision_config = GEMMA4_VISION_31B_CONFIG +@dataclass +class Gemma4_12B_Config(Gemma4Config): + hidden_size: int = 3840 + intermediate_size: int = 15360 + num_hidden_layers: int = 48 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + num_global_key_value_heads = 1 + attention_k_eq_v = True + vision_bidirectional = True + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + hidden_size_per_layer_input: int = 0 + num_kv_shared_layers: int = 0 + audio_config = GEMMA4_UNIFIED_AUDIO_CONFIG + vision_config = GEMMA4_UNIFIED_VISION_CONFIG + # unfused RoPE as addcmul_ RoPE diverges from reference code def _apply_rotary_pos_emb(x, freqs_cis): @@ -89,17 +116,18 @@ def _apply_rotary_pos_emb(x, freqs_cis): return out class Gemma4Attention(nn.Module): - def __init__(self, config, head_dim, device=None, dtype=None, ops=None): + def __init__(self, config, head_dim, num_kv_heads=None, k_eq_v=False, device=None, dtype=None, ops=None): super().__init__() self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads + self.num_kv_heads = num_kv_heads if num_kv_heads is not None else config.num_key_value_heads self.hidden_size = config.hidden_size self.head_dim = head_dim self.inner_size = self.num_heads * head_dim self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype) self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) - self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) + # k_eq_v: V reuses the K projection (no separate v_proj weight) + self.v_proj = None if k_eq_v else ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype) self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype) self.q_norm = None @@ -133,7 +161,10 @@ class Gemma4Attention(nn.Module): shareable_kv = None else: xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) - xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + if self.v_proj is not None: + xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim) + else: + xv = xk # k_eq_v: V is the raw K projection (before k_norm/RoPE) if self.k_norm is not None: xk = self.k_norm(xk) xv = rms_norm(xv) @@ -186,7 +217,10 @@ class TransformerBlockGemma4(nn.Module): head_dim = config.head_dim if self.sliding_attention else config.global_head_dim - self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops) + # k_eq_v only on global layers, which then use num_global_key_value_heads + k_eq_v = config.attention_k_eq_v and not self.sliding_attention + num_kv_heads = config.num_global_key_value_heads if k_eq_v else config.num_key_value_heads + self.self_attn = Gemma4Attention(config, head_dim=head_dim, num_kv_heads=num_kv_heads, k_eq_v=k_eq_v, device=device, dtype=dtype, ops=ops) num_kv_shared = config.num_kv_shared_layers first_kv_shared = config.num_hidden_layers - num_kv_shared @@ -203,9 +237,9 @@ class TransformerBlockGemma4(nn.Module): self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype) self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype) self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) - self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype)) - else: - self.layer_scalar = None + + # layer_scalar exists on every gemma4 variant, independent of per-layer input + self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype)) def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None): sliding_window = None @@ -245,7 +279,7 @@ class TransformerBlockGemma4(nn.Module): x = residual + x if self.layer_scalar is not None: - x = x * self.layer_scalar + x = x * comfy.model_management.cast_to_device(self.layer_scalar, x.device, x.dtype) return x, present_key_value, shareable_kv @@ -334,6 +368,19 @@ class Gemma4Transformer(nn.Module): causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val) mask = mask + causal_mask if mask is not None else causal_mask + # Bidirectional attention within each image soft-token block (prefill only; text/audio stay causal). + if getattr(self.config, "vision_bidirectional", False) and past_len == 0 and embeds_info: + block_ids = torch.full((seq_len,), -1, dtype=torch.long, device=x.device) + group = 0 + for info in embeds_info: + if info.get("type") == "image": + start = info["index"] + block_ids[start:start + info["size"]] = group + group += 1 + if group > 0: + same_block = (block_ids[:, None] == block_ids[None, :]) & (block_ids[:, None] >= 0) + mask = mask.masked_fill(same_block, 0.0) + # Per-layer inputs per_layer_inputs = None if self.hidden_size_per_layer_input: @@ -441,6 +488,28 @@ class Gemma4AudioMixin: return None, None +class Gemma4UnifiedBase(Gemma4Base): + """Encoder-free multimodal Gemma4 (gemma4_unified, e.g. 12B): raw image patches and audio frames projected directly into LM space.""" + def _init_model(self, config, dtype, device, operations): + self.num_layers = config.num_hidden_layers + self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + self.vision_model = Gemma4UnifiedVisionEmbedder(config.vision_config, device=device, dtype=dtype, ops=operations) + self.multi_modal_projector = Gemma4RMSNormProjector(config.vision_config["output_proj_dims"], config.hidden_size, dtype=dtype, device=device, ops=operations) + self.audio_projector = Gemma4RMSNormProjector(config.audio_config["output_proj_dims"], config.hidden_size, dtype=dtype, device=device, ops=operations) + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + pixels = embed.pop("data").movedim(-1, 1).to(device, dtype=self.dtype) # [B, H, W, C] -> [B, C, H, W], [0,1] + patches, positions = self.vision_model.patchify(pixels) + vision_out = self.vision_model(patches, positions) + return self.multi_modal_projector(vision_out), None + if embed["type"] == "audio": + audio = embed.pop("data").to(device, dtype=self.dtype) # [1, T, audio_samples_per_token] + return self.audio_projector(audio), None + return None, None + + # Vision Encoder def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None): @@ -713,6 +782,73 @@ class Gemma4MultiModalProjector(Gemma4RMSNormProjector): super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops) +# Encoder-free vision (gemma4_unified): raw merged pixel patches projected directly into LM space. + +def _patches_merge(patches, positions_xy, length): + patch_size = math.isqrt(patches.shape[-1] // 3) + k = math.isqrt(patches.shape[-2] // length) + batch = patches.shape[:-2] + + max_x = positions_xy[..., 0].max(dim=-1, keepdim=True)[0] + 1 + kidx = torch.div(positions_xy, k, rounding_mode="floor") + rem = torch.remainder(positions_xy, k) + order = rem[..., 0] + rem[..., 1] * k + k * k * kidx[..., 0] + k * max_x * kidx[..., 1] + perm = order.long().argsort(dim=-1) + + merged = patches.gather(-2, perm.unsqueeze(-1).expand_as(patches)) + merged = merged.reshape(*batch, length, k, k, patch_size, patch_size, 3) + merged = merged.permute(*range(len(batch)), -6, -5, -3, -4, -2, -1).reshape(*batch, length, (k * patch_size) ** 2 * 3) + + pos = positions_xy.float().gather(-2, perm.unsqueeze(-1).expand_as(positions_xy).long()) + pad = (positions_xy == -1).all(dim=-1, keepdim=True) + pos = torch.where(pad, positions_xy.float(), pos).reshape(*batch, length, k * k, 2) + pos = torch.div(pos, k, rounding_mode="floor").min(dim=-2)[0].to(torch.long) + return merged, pos + + +class Gemma4UnifiedVisionEmbedder(nn.Module): + """Encoder-free patch embedder (LN -> Dense -> LN -> +2D posemb -> LN); projection to text space is the separate multi_modal_projector.""" + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.patch_size = config["patch_size"] + self.pooling_kernel_size = config["pooling_kernel_size"] + patch_dim = config["model_patch_size"] ** 2 * 3 + mm_embed_dim = config["mm_embed_dim"] + self.patch_ln1 = ops.LayerNorm(patch_dim, device=device, dtype=dtype) + self.patch_dense = ops.Linear(patch_dim, mm_embed_dim, device=device, dtype=dtype) + self.patch_ln2 = ops.LayerNorm(mm_embed_dim, device=device, dtype=dtype) + self.pos_embedding = nn.Parameter(torch.empty(config["mm_posemb_size"], 2, mm_embed_dim, device=device, dtype=dtype)) + self.pos_norm = ops.LayerNorm(mm_embed_dim, device=device, dtype=dtype) + + def patchify(self, pixels): + """pixels: [B, C, H, W] in [0,1] -> merged patches [B, N, 6912], positions [B, N, 2].""" + ps, k = self.patch_size, self.pooling_kernel_size + out_patches, out_positions = [], [] + for img in pixels: + ph, pw = img.shape[-2] // ps, img.shape[-1] // ps + teacher = img.reshape(img.shape[0], ph, ps, pw, ps).permute(1, 3, 2, 4, 0).reshape(ph * pw, -1) + grid = torch.meshgrid(torch.arange(pw, device=img.device), torch.arange(ph, device=img.device), indexing="xy") + tpos = torch.stack(grid, dim=-1).reshape(teacher.shape[0], 2) + n_model = teacher.shape[0] // (k * k) + mp, mpos = _patches_merge(teacher.unsqueeze(0), tpos.unsqueeze(0), n_model) + out_patches.append(mp.squeeze(0)) + out_positions.append(mpos.squeeze(0)) + return torch.stack(out_patches), torch.stack(out_positions) + + def forward(self, pixel_values, image_position_ids): + x = self.patch_ln1(pixel_values) + x = self.patch_dense(x) + x = self.patch_ln2(x) + + clamped = image_position_ids.clamp(min=0).long() + valid = (image_position_ids != -1).to(x.dtype).unsqueeze(-1) + axes = torch.arange(2, device=image_position_ids.device) + pos = comfy.model_management.cast_to_device(self.pos_embedding, x.device, x.dtype) + pos_embs = (pos[clamped, axes] * valid).sum(-2) + x = x + pos_embs + return self.pos_norm(x) + + # Audio Encoder class Gemma4AudioConvSubsampler(nn.Module): @@ -998,25 +1134,35 @@ class Gemma4_Tokenizer(): return {"tokenizer_json": self.tokenizer_json_data} return {} - def _extract_mel_spectrogram(self, waveform, sample_rate): - """Extract 128-bin log mel spectrogram. - Uses numpy for FFT/matmul/log to produce bit-identical results with reference code. - """ - # Mix to mono first, then resample to 16kHz + def _audio_token_count(self, num_samples): + # Default (E2B/E4B): mel frames after two stride-2 conv subsamples. + _fl = 320 # int(round(16000 * 20.0 / 1000.0)) + _hl = 160 # int(round(16000 * 10.0 / 1000.0)) + _nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1 + _t = _nmel + for _ in range(2): + _t = (_t + 2 - 3) // 2 + 1 + return min(_t, 750) + + @staticmethod + def _resample_16k(waveform, sample_rate): + """Mix to mono and resample to 16kHz. Kaiser params reproduce the reference (transformers + load_audio -> librosa/soxr_hq) to ~1e-12 MSE using only torchaudio.""" if waveform.dim() > 1 and waveform.shape[0] > 1: waveform = waveform.mean(dim=0, keepdim=True) if waveform.dim() == 1: waveform = waveform.unsqueeze(0) - audio = waveform.squeeze(0).float().numpy() + audio = waveform.float() if sample_rate != 16000: - # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (while still not full match) - from scipy.signal import resample_poly, firwin - from math import gcd - g = gcd(sample_rate, 16000) - up, down = 16000 // g, sample_rate // g - L = max(up, down) - h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5)) - audio = resample_poly(audio, up, down, window=h).astype(np.float32) + audio = AF.resample(audio, sample_rate, 16000, resampling_method="sinc_interp_kaiser", + lowpass_filter_width=121, rolloff=0.9568384289091556, beta=21.01531462440614) + return audio.squeeze(0).contiguous() + + def _extract_audio_features(self, waveform, sample_rate): + """Default (E2B/E4B): 128-bin log mel spectrogram for the conformer audio encoder. + Uses numpy for FFT/matmul/log to produce bit-identical results with reference code. + """ + audio = self._resample_16k(waveform, sample_rate).numpy() n = len(audio) # Pad to multiple of 128, build sample-level mask @@ -1064,8 +1210,8 @@ class Gemma4_Tokenizer(): if audio is not None: waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000 - mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate) - audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T]) + feat, feat_mask = self._extract_audio_features(waveform, sample_rate) + audio_features = [(feat.unsqueeze(0), feat_mask.unsqueeze(0))] # ([1, T, D], [1, T]) # Process image/video frames is_video = video is not None @@ -1096,7 +1242,6 @@ class Gemma4_Tokenizer(): target_h = max(int(factor * h // side_mult) * side_mult, side_mult) target_w = max(int(factor * w // side_mult) * side_mult, side_mult) - import torchvision.transforms.functional as TVF for i in range(num_frames): # rescaling to match reference code s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8 @@ -1115,7 +1260,7 @@ class Gemma4_Tokenizer(): llama_text = llama_template.format(text) else: # Build template from modalities present - system = "<|turn>system\n<|think|>\n" if thinking else "" + system = "<|turn>system\n<|think|>\n\n" if thinking else "" media = "" if len(images) > 0: if is_video: @@ -1135,15 +1280,11 @@ class Gemma4_Tokenizer(): if len(audio_features) > 0: # Compute audio token count (always at 16kHz) num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1] - _fl = 320 # int(round(16000 * 20.0 / 1000.0)) - _hl = 160 # int(round(16000 * 10.0 / 1000.0)) - _nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1 - _t = _nmel - for _ in range(2): - _t = (_t + 2 - 3) // 2 + 1 - n_audio_tokens = min(_t, 750) + n_audio_tokens = self._audio_token_count(num_samples) media += "<|audio>" + "<|audio|>" * n_audio_tokens + "" - llama_text = f"{system}<|turn>user\n{media}{text}\n<|turn>model\n" + # Non-thinking mode primes an empty thought channel so the model answers directly. + model_open = "" if thinking else "<|channel>thought\n" + llama_text = f"{system}<|turn>user\n{text}{media}\n<|turn>model\n{model_open}" text_tokens = super().tokenize_with_weights(llama_text, return_word_ids) @@ -1178,7 +1319,6 @@ class Gemma4_Tokenizer(): class _Gemma4Tokenizer: """Tokenizer using the tokenizers (Gemma4 doesn't come with sentencepiece model)""" def __init__(self, tokenizer_json_bytes=None, **kwargs): - from tokenizers import Tokenizer if isinstance(tokenizer_json_bytes, torch.Tensor): tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist()) self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8")) @@ -1224,6 +1364,30 @@ class Gemma4Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma4", tokenizer=self.tokenizer_class) +class Gemma4UnifiedSDTokenizer(Gemma4SDTokenizer): + """Encoder-free (gemma4_unified) audio: raw 16kHz waveform frames instead of mel spectrogram.""" + embedding_size = 3840 + + def _extract_audio_features(self, waveform, sample_rate): + audio = self._resample_16k(waveform, sample_rate) + spt = 640 # audio_samples_per_token (40ms at 16kHz) + pad = (-audio.shape[0]) % spt + if pad: + audio = torch.nn.functional.pad(audio, (0, pad)) + num_tokens = audio.shape[0] // spt + feats = audio[:num_tokens * spt].reshape(num_tokens, spt) + feats = feats[:750] # audio_seq_length cap (matches reference truncation, ~30s) + mask = torch.ones(feats.shape[0], dtype=torch.bool) + return feats, mask + + def _audio_token_count(self, num_samples): + return min((num_samples + 639) // 640, 750) + + +class Gemma4UnifiedTokenizer(Gemma4Tokenizer): + tokenizer_class = Gemma4UnifiedSDTokenizer + + # Model wrappers class Gemma4Model(sd1_clip.SDClipModel): model_class = None @@ -1256,7 +1420,7 @@ class Gemma4Model(sd1_clip.SDClipModel): expanded_idx += 1 initial_token_ids = [ids] input_ids = torch.tensor(initial_token_ids, device=self.execution_device) - return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids) + return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, initial_tokens=initial_token_ids[0], presence_penalty=presence_penalty, initial_input_ids=input_ids, embeds_info=embeds_info) def gemma4_te(dtype_llama=None, llama_quantization_metadata=None, model_class=None): @@ -1296,3 +1460,11 @@ def _make_variant(config_cls): Gemma4_E4B = _make_variant(Gemma4Config) Gemma4_E2B = _make_variant(Gemma4_E2B_Config) Gemma4_31B = _make_variant(Gemma4_31B_Config) + + +# Gemma4 12B Unified: encoder-free multimodal, distinct base/tokenizer (not via _make_variant). +class Gemma4_12B(Gemma4UnifiedBase): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self._init_model(Gemma4_12B_Config(**config_dict), dtype, device, operations) +Gemma4_12B.tokenizer = Gemma4UnifiedTokenizer diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 5087228ca..7181c78d1 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -860,7 +860,7 @@ class BaseGenerate: torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) return past_key_values - def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None): + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None, embeds_info=None): device = embeds.device if stop_tokens is None: @@ -887,7 +887,7 @@ class BaseGenerate: # Generation loop current_input_ids = initial_input_ids for step in tqdm(range(max_length), desc="Generating tokens"): - x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids) + x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids, embeds_info=(embeds_info if step == 0 else None)) logits = self.logits(x)[:, -1] next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) token_id = next_token[0].item()