Should be working now.

This commit is contained in:
Talmaj Marinc 2026-02-22 23:17:18 +01:00
parent 2e9d4e967b
commit ab708b4b40
2 changed files with 16 additions and 4 deletions

View File

@ -1784,6 +1784,10 @@ class LongCatImage(supported_models_base.BASE):
elif k.startswith("time_embed.timestep_embedder.linear_2."):
out_sd["time_in.out_layer." + k.split(".")[-1]] = v
elif k.startswith("norm_out.linear."):
# HF AdaLayerNormContinuous stores [scale | shift] but ComfyUI
# LastLayer expects [shift | scale], so swap the two halves.
half = v.shape[0] // 2
v = torch.cat([v[half:], v[:half]], dim=0)
out_sd["final_layer.adaLN_modulation.1." + k.split(".")[-1]] = v
elif k == "proj_out.weight" or k == "proj_out.bias":
out_sd["final_layer.linear." + k.split(".")[-1]] = v

View File

@ -3,7 +3,9 @@ import numbers
import torch
from comfy import sd1_clip
from comfy.text_encoders.qwen_image import Qwen25_7BVLITokenizer, Qwen25_7BVLIModel
import logging
logger = logging.getLogger(__name__)
QUOTE_PAIRS = [("'", "'"), ('"', '"'), ("\u2018", "\u2019"), ("\u201c", "\u201d")]
QUOTE_PATTERN = "|".join(
@ -33,6 +35,10 @@ def split_quotation(prompt):
class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max_length = 512
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
parts = split_quotation(text)
all_tokens = []
@ -45,11 +51,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
ids = self.tokenizer(part_text, add_special_tokens=False)["input_ids"]
all_tokens.extend(ids)
max_len = self.max_length if self.max_length < 99999999 else 512
if len(all_tokens) > max_len:
all_tokens = all_tokens[:max_len]
if len(all_tokens) > self.max_length:
all_tokens = all_tokens[:self.max_length]
logger.warning(f"Truncated prompt to {self.max_length} tokens")
output = [(t, 1.0) for t in all_tokens]
# Pad to max length
self.pad_tokens(output, self.max_length - len(output))
return [output]
@ -113,7 +121,7 @@ class LongCatImageTEModel(sd1_clip.SD1ClipModel):
for i in range(len(tok_pairs) - 1, -1, -1):
elem = tok_pairs[i][0]
if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral):
if elem == 151644:
if elem == 151645:
suffix_start = i
break