From 0fc398a821bc7362ea2812b6db3d8f42f50af7e3 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Sun, 12 Apr 2026 18:11:07 +0300
Subject: [PATCH] Various fixes

---
 comfy/ldm/modules/attention.py | 6 +++++-
 comfy/sd.py                    | 2 +-
 comfy/text_encoders/gemma4.py  | 5 +++++
 comfy_extras/nodes_textgen.py  | 4 ++--
 4 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 43cecad7f..a68cb8439 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 from comfy import model_management
 
+TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
+
 if model_management.xformers_enabled():
     import xformers
     import xformers.ops
 
@@ -510,7 +512,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             mask = mask.unsqueeze(1)
 
     # Pass through extra SDPA kwargs (scale, enable_gqa) if provided
-    sdpa_extra = {k: v for k, v in kwargs.items() if k in ("scale", "enable_gqa")}
+    # enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
+    sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
+    sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
 
     if SDP_BATCH_LIMIT >= b:
         out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
diff --git a/comfy/sd.py b/comfy/sd.py
index 06f46211e..3c19a4bb6 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1295,7 +1295,7 @@ def detect_te_model(sd):
         if weight.shape[0] == 4096:
             return TEModel.QWEN35_9B
         if weight.shape[0] == 5120:
-            return TEModel.QWEN35_31B
+            return TEModel.QWEN35_27B
         return TEModel.QWEN35_2B
     if "model.layers.0.post_attention_layernorm.weight" in sd:
         weight = sd['model.layers.0.post_attention_layernorm.weight']
diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py
index 9c2004a46..9573cd427 100644
--- a/comfy/text_encoders/gemma4.py
+++ b/comfy/text_encoders/gemma4.py
@@ -1004,7 +1004,11 @@ class Gemma4AudioProjector(Gemma4RMSNormProjector):
 
 # Tokenizer and Wrappers
 class Gemma4_Tokenizer():
+    tokenizer_json_data = None
+
     def state_dict(self):
+        if self.tokenizer_json_data is not None:
+            return {"tokenizer_json": self.tokenizer_json_data}
         return {}
 
     def _extract_mel_spectrogram(self, waveform, sample_rate):
@@ -1217,6 +1221,7 @@ class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer):
     embedding_size = 2560
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer_json = tokenizer_data.get("tokenizer_json", None)
+        self.tokenizer_json_data = tokenizer_json
         super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data)
 
 
diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py
index 4235fd310..b4f793f9a 100644
--- a/comfy_extras/nodes_textgen.py
+++ b/comfy_extras/nodes_textgen.py
@@ -162,12 +162,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
         )
 
     @classmethod
-    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
+    def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput:
         if image is None:
             formatted_prompt = f"system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}\nuser\nUser Raw Input Prompt: {prompt}.\nmodel\n"
         else:
             formatted_prompt = f"system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}\nuser\n\n\n\nUser Raw Input Prompt: {prompt}.\nmodel\n"
-        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
+        return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, video=video, audio=audio, thinking=thinking)
 
 
 class TextgenExtension(ComfyExtension):
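
Reviewer note: a minimal, self-contained sketch of the behavior the new
TORCH_HAS_GQA gate accounts for, using plain torch rather than ComfyUI's
comfy.ops wrapper (shapes are illustrative). scaled_dot_product_attention only
accepts enable_gqa on PyTorch 2.5+, where it broadcasts the KV heads to the
query head count internally; on older versions passing the kwarg raises a
TypeError, so it has to be dropped and the KV heads expanded manually:

    import torch
    import torch.nn.functional as F

    # Grouped-query attention: 8 query heads share 2 KV heads, (B, H, S, D) layout.
    q = torch.randn(1, 8, 32, 64)
    k = torch.randn(1, 2, 32, 64)
    v = torch.randn(1, 2, 32, 64)

    # Rough stand-in for model_management.torch_version_numeric >= (2, 5).
    torch_mm = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])

    if torch_mm >= (2, 5):
        out = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)
    else:
        # Fallback: repeat each KV head until K/V match the query head count.
        repeats = q.shape[1] // k.shape[1]
        out = F.scaled_dot_product_attention(
            q,
            k.repeat_interleave(repeats, dim=1),
            v.repeat_interleave(repeats, dim=1),
        )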
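Reviewer note on the gemma4.py change: previously Gemma4_Tokenizer.state_dict()
always returned {}, so a tokenizer definition supplied through
tokenizer_data["tokenizer_json"] was lost whenever the tokenizer state was
re-saved. A self-contained sketch of the round-trip the patch restores;
TokenizerStateDemo is a hypothetical stand-in, since the real Gemma4SDTokenizer
needs a full tokenizer to construct:

    class TokenizerStateDemo:
        # Mirrors the patched classes: remember the JSON at init and hand it
        # back from state_dict() so it survives a save/load cycle.
        tokenizer_json_data = None

        def __init__(self, tokenizer_data={}):
            self.tokenizer_json_data = tokenizer_data.get("tokenizer_json", None)

        def state_dict(self):
            if self.tokenizer_json_data is not None:
                return {"tokenizer_json": self.tokenizer_json_data}
            return {}

    saved = TokenizerStateDemo({"tokenizer_json": "<json blob>"}).state_dict()
    assert TokenizerStateDemo(saved).state_dict() == saved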