parity with reference implementation

Outputs now match transformers 100% when the same SDPA flags are used; checkpoint this state, then optimize.
This commit is contained in:
kijai 2026-04-07 01:15:04 +03:00
parent 832753f497
commit 93e8635110
10 changed files with 613 additions and 433 deletions

View File

@ -150,7 +150,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
n_rep = q.shape[-3] // k.shape[-3]
k = k.repeat_interleave(n_rep, dim=-3)
v = v.repeat_interleave(n_rep, dim=-3)
scale = kwargs.get("scale", dim_head ** -0.5)
h = heads
if skip_reshape:
@ -219,6 +224,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape
dim_head //= heads
if "scale" in kwargs:
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
query = query * (kwargs["scale"] * dim_head ** 0.5)
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
@ -290,7 +299,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape:
q, k, v = map(
@ -500,8 +509,11 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
sdpa_extra = {k: v for k, v in kwargs.items() if k in ("scale", "enable_gqa")}
if SDP_BATCH_LIMIT >= b:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -519,7 +531,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
dropout_p=0.0, is_causal=False
dropout_p=0.0, is_causal=False, **sdpa_extra
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out

View File

@ -3,7 +3,15 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6):
def rms_norm(x, weight=None, eps=1e-6, fused=True):
if not fused:
orig_dtype = x.dtype
normed = x.float() * torch.pow(x.float().pow(2).mean(-1, keepdim=True) + eps, -0.5)
if weight is not None:
weight = comfy.model_management.cast_to(weight, dtype=torch.float32, device=x.device)
normed = normed * weight
return normed.to(orig_dtype)
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
else:

View File

@ -1230,6 +1230,8 @@ class TEModel(Enum):
QWEN35_9B = 26
QWEN35_27B = 27
GEMMA_4_E4B = 28
GEMMA_4_E2B = 29
GEMMA_4_31B = 30
def detect_te_model(sd):
@ -1255,8 +1257,12 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E4B
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E2B
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@ -1280,7 +1286,7 @@ def detect_te_model(sd):
if weight.shape[0] == 4096:
return TEModel.QWEN35_9B
if weight.shape[0] == 5120:
return TEModel.QWEN35_27B
return TEModel.QWEN35_27B
return TEModel.QWEN35_2B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
@ -1395,9 +1401,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model == TEModel.GEMMA_4_E4B:
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data))
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_E4B)
clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4Tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_4_E2B:
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_E2B)
clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4_E2BTokenizerWrapper
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_4_31B:
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=comfy.text_encoders.gemma4.Gemma4_31B)
clip_target.tokenizer = comfy.text_encoders.gemma4.Gemma4_31BTokenizerWrapper
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

File diff suppressed because it is too large Load Diff

View File

@ -358,18 +358,19 @@ class Gemma3_12B_Config:
stop_tokens = [1, 106]
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None, fused=True):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
self.add = add
self.fused = fused
def forward(self, x: torch.Tensor):
w = self.weight
if self.add:
w = w + 1.0
return comfy.ldm.common_dit.rms_norm(x, w, self.eps)
return comfy.ldm.common_dit.rms_norm(x, w, self.eps, fused=self.fused)
@ -497,7 +498,7 @@ class Attention(nn.Module):
else:
present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window:
if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
@ -509,12 +510,12 @@ class Attention(nn.Module):
return self.o_proj(output), present_key_value
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
super().__init__()
ops = ops or nn
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.mlp_activation == "silu":
self.activation = torch.nn.functional.silu
elif config.mlp_activation == "gelu_pytorch_tanh":
@ -623,6 +624,10 @@ class TransformerBlockGemma2(nn.Module):
return x, present_key_value
def _gemma_embed_scale_hook(module, input, output):
    # Forward hook on the embed_tokens module: scales the embedding output by
    # module._embed_scale (a registered buffer holding hidden_size ** 0.5).
    # The multiply is done in the buffer's dtype for numerical parity with the
    # reference implementation, then cast back to the activation dtype.
    # NOTE(review): `input` is unused — required by the nn.Module hook signature.
    return (output.to(module._embed_scale.dtype) * module._embed_scale).to(output.dtype)
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@ -637,10 +642,10 @@ class Llama2_(nn.Module):
)
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.normalize_in = True
self.embed_tokens.register_buffer("_embed_scale", torch.tensor(config.hidden_size ** 0.5, dtype=dtype or self.embed_tokens.weight.dtype), persistent=False)
self.embed_tokens.register_forward_hook(_gemma_embed_scale_hook)
else:
transformer = TransformerBlock
self.normalize_in = False
self.layers = nn.ModuleList([
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
@ -672,9 +677,6 @@ class Llama2_(nn.Module):
else:
x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:

View File

@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):

View File

@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
embeds, _, _, _ = super().process_tokens(tokens, device)
return embeds
class LuminaModel(sd1_clip.SD1ClipModel):

View File

@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
nn.Module.__init__(self)
self.config = config
self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)

View File

@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@ -32,6 +32,7 @@ class TextGenerate(io.ComfyNode):
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("image", optional=True),
io.Image.Input("video", optional=True, tooltip="Video frames as image batch (1 FPS recommended)."),
io.Audio.Input("audio", optional=True),
io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
@ -43,9 +44,9 @@ class TextGenerate(io.ComfyNode):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, audio=None, thinking=False) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, audio=audio, skip_template=False, min_length=1, thinking=thinking)
tokens = clip.tokenize(prompt, image=image, video=video, audio=audio, skip_template=False, min_length=1, thinking=thinking)
# Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on"