Various fixes

This commit is contained in:
kijai 2026-04-12 18:11:07 +03:00
parent abff4cf3d6
commit 0fc398a821
4 changed files with 13 additions and 4 deletions

View File

@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
if model_management.xformers_enabled():
import xformers
import xformers.ops
@ -510,7 +512,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
sdpa_extra = {k: v for k, v in kwargs.items() if k in ("scale", "enable_gqa")}
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
if SDP_BATCH_LIMIT >= b:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)

View File

@ -1295,7 +1295,7 @@ def detect_te_model(sd):
if weight.shape[0] == 4096:
return TEModel.QWEN35_9B
if weight.shape[0] == 5120:
return TEModel.QWEN35_31B
return TEModel.QWEN35_27B
return TEModel.QWEN35_2B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']

View File

@ -1004,7 +1004,11 @@ class Gemma4AudioProjector(Gemma4RMSNormProjector):
# Tokenizer and Wrappers
class Gemma4_Tokenizer():
tokenizer_json_data = None
def state_dict(self):
if self.tokenizer_json_data is not None:
return {"tokenizer_json": self.tokenizer_json_data}
return {}
def _extract_mel_spectrogram(self, waveform, sample_rate):
@ -1217,6 +1221,7 @@ class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer):
embedding_size = 2560
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_json = tokenizer_data.get("tokenizer_json", None)
self.tokenizer_json_data = tokenizer_json
super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data)

View File

@ -162,12 +162,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput:
if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, video=video, audio=audio, thinking=thinking)
class TextgenExtension(ComfyExtension):