mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-27 10:52:31 +08:00
Various fixes
This commit is contained in:
parent
abff4cf3d6
commit
0fc398a821
@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
|
||||
|
||||
from comfy import model_management
|
||||
|
||||
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
@ -510,7 +512,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
|
||||
sdpa_extra = {k: v for k, v in kwargs.items() if k in ("scale", "enable_gqa")}
|
||||
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
|
||||
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
|
||||
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
|
||||
|
||||
@ -1295,7 +1295,7 @@ def detect_te_model(sd):
|
||||
if weight.shape[0] == 4096:
|
||||
return TEModel.QWEN35_9B
|
||||
if weight.shape[0] == 5120:
|
||||
return TEModel.QWEN35_31B
|
||||
return TEModel.QWEN35_27B
|
||||
return TEModel.QWEN35_2B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
|
||||
@ -1004,7 +1004,11 @@ class Gemma4AudioProjector(Gemma4RMSNormProjector):
|
||||
# Tokenizer and Wrappers
|
||||
|
||||
class Gemma4_Tokenizer():
|
||||
tokenizer_json_data = None
|
||||
|
||||
def state_dict(self):
|
||||
if self.tokenizer_json_data is not None:
|
||||
return {"tokenizer_json": self.tokenizer_json_data}
|
||||
return {}
|
||||
|
||||
def _extract_mel_spectrogram(self, waveform, sample_rate):
|
||||
@ -1217,6 +1221,7 @@ class Gemma4SDTokenizer(Gemma4_Tokenizer, sd1_clip.SDTokenizer):
|
||||
embedding_size = 2560
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_json = tokenizer_data.get("tokenizer_json", None)
|
||||
self.tokenizer_json_data = tokenizer_json
|
||||
super().__init__(tokenizer_json, pad_with_end=False, embedding_size=self.embedding_size, embedding_key='gemma4', tokenizer_class=_Gemma4Tokenizer, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, start_token=2, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
|
||||
@ -162,12 +162,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput:
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, video=None, audio=None, thinking=False) -> io.NodeOutput:
|
||||
if image is None:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
else:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking)
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, video=video, audio=audio, thinking=thinking)
|
||||
|
||||
|
||||
class TextgenExtension(ComfyExtension):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user