Don't add template to qwen2.5vl when template is in prompt. (#10043)

Make the hunyuan image refiner template_end 36.
comfyanonymous 2025-09-26 15:34:17 -07:00 committed by GitHub
parent cd66d72b46
commit 1e098d6132
2 changed files with 35 additions and 19 deletions
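
For context, "template is in prompt" means the prompt text itself already begins with a chat-template header, so wrapping it in the tokenizer's default template would duplicate it. A minimal sketch of the distinction, with made-up prompt strings; only the "<|im_start|>" / "<|start_header_id|>" prefixes come from this change:

# Hypothetical prompts illustrating the two paths (not taken from the repo).
plain_prompt = "a watercolor painting of a lighthouse at dusk"

templated_prompt = (
    "<|im_start|>system\nYou are a helpful captioning assistant.<|im_end|>\n"
    "<|im_start|>user\na watercolor painting of a lighthouse at dusk<|im_end|>\n"
    "<|im_start|>assistant\n"
)

llama3_prompt = "<|start_header_id|>user<|end_header_id|>\n\na watercolor painting<|eot_id|>"

for p in (plain_prompt, templated_prompt, llama3_prompt):
    # Same prefix test the tokenizer now performs before applying its default template.
    skips = p.startswith("<|im_start|>") or p.startswith("<|start_header_id|>")
    print(skips, repr(p[:34]))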

@@ -63,7 +63,13 @@ class HunyuanImageTEModel(QwenImageTEModel):
             self.byt5_small = None
 
     def encode_token_weights(self, token_weight_pairs):
-        cond, p, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
+        template_end = -1
+        if tok_pairs[0][0] == 27:
+            if len(tok_pairs) > 36:  # refiner prompt uses a fixed 36 template_end
+                template_end = 36
+
+        cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=template_end)
         if self.byt5_small is not None and "byt5" in token_weight_pairs:
             out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
             extra["conditioning_byt5small"] = out[0]
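
The check above can be read in isolation as the following sketch; the helper name and the reading of token id 27 as a literal '<' in the Qwen2.5-VL vocabulary are my assumptions, while the fixed length 36 comes from the diff:

def refiner_template_end(tok_pairs, refiner_template_len=36):
    # tok_pairs: list of (token_id, weight) tuples from the "qwen25_7b" stream.
    # A prompt that starts with a literal '<' (token id 27, assumed to be the plain
    # '<' byte) and is longer than the fixed 36-token refiner template is treated
    # as a refiner prompt with a known template length.
    if len(tok_pairs) > refiner_template_len and tok_pairs[0][0] == 27:
        return refiner_template_len
    return -1  # let the parent class scan for <|im_start|> tokens instead

print(refiner_template_end([(27, 1.0)] + [(0, 1.0)] * 40))  # 36: refiner-style prompt
print(refiner_template_end([(151644, 1.0), (872, 1.0)]))    # -1: fall back to the scan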

@@ -18,13 +18,22 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
         self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
 
     def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
-        if llama_template is None:
-            if len(images) > 0:
-                llama_text = self.llama_template_images.format(text)
-            else:
-                llama_text = self.llama_template.format(text)
-        else:
-            llama_text = llama_template.format(text)
+        skip_template = False
+        if text.startswith('<|im_start|>'):
+            skip_template = True
+        if text.startswith('<|start_header_id|>'):
+            skip_template = True
+
+        if skip_template:
+            llama_text = text
+        else:
+            if llama_template is None:
+                if len(images) > 0:
+                    llama_text = self.llama_template_images.format(text)
+                else:
+                    llama_text = self.llama_template.format(text)
+            else:
+                llama_text = llama_template.format(text)
         tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
         key_name = next(iter(tokens))
         embed_count = 0
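
The same selection logic, pulled out as a pure function so the behavior change is easy to see in isolation (function and parameter names here are illustrative, not from the repo):

def build_llama_text(text, default_template, image_template, llama_template=None, images=()):
    # Sketch of the template selection in tokenize_with_weights above. Prompts that
    # already start with a Qwen ("<|im_start|>") or Llama 3 ("<|start_header_id|>")
    # chat header are passed through untouched; everything else is wrapped as before.
    if text.startswith("<|im_start|>") or text.startswith("<|start_header_id|>"):
        return text
    if llama_template is not None:
        return llama_template.format(text)
    if len(images) > 0:
        return image_template.format(text)
    return default_template.format(text)

# Toy usage with a stand-in template.
default = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
print(build_llama_text("a red bicycle", default, default))
print(build_llama_text("<|im_start|>user\nalready templated<|im_end|>", default, default))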
@@ -47,22 +56,23 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}):
         super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
 
-    def encode_token_weights(self, token_weight_pairs):
+    def encode_token_weights(self, token_weight_pairs, template_end=-1):
         out, pooled, extra = super().encode_token_weights(token_weight_pairs)
         tok_pairs = token_weight_pairs["qwen25_7b"][0]
         count_im_start = 0
-        template_end = 0
-        for i, v in enumerate(tok_pairs):
-            elem = v[0]
-            if not torch.is_tensor(elem):
-                if isinstance(elem, numbers.Integral):
-                    if elem == 151644 and count_im_start < 2:
-                        template_end = i
-                        count_im_start += 1
+        if template_end == -1:
+            template_end = 0
+            for i, v in enumerate(tok_pairs):
+                elem = v[0]
+                if not torch.is_tensor(elem):
+                    if isinstance(elem, numbers.Integral):
+                        if elem == 151644 and count_im_start < 2:
+                            template_end = i
+                            count_im_start += 1
 
         if out.shape[1] > (template_end + 3):
             if tok_pairs[template_end + 1][0] == 872:
                 if tok_pairs[template_end + 2][0] == 198:
                     template_end += 3
 
         out = out[:, template_end:]
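
To make the trimming at the end concrete, here is a toy run of the same logic on fake data, assuming (as the code does) that 151644 is <|im_start|>, 872 the "user" token and 198 a newline in the Qwen2.5-VL vocabulary; the remaining ids and the tensor shape are made up. On the Hunyuan refiner path, the subclass above skips this scan entirely by passing template_end=36.

import torch

# Fake embeddings for a 10-token sequence (batch 1, hidden size 8) and fake
# (token_id, weight) pairs; ids other than 151644, 872 and 198 are placeholders.
out = torch.randn(1, 10, 8)
tok_pairs = [(151644, 1.0), (9999, 1.0), (9998, 1.0), (9997, 1.0),
             (151644, 1.0), (872, 1.0), (198, 1.0),
             (1234, 1.0), (5678, 1.0), (9012, 1.0)]

# Same scan as above: template_end lands on the second <|im_start|> (id 151644).
template_end, count_im_start = 0, 0
for i, (tok, _w) in enumerate(tok_pairs):
    if tok == 151644 and count_im_start < 2:
        template_end = i
        count_im_start += 1

# Step past "user" (872) and the following newline (198), then drop the template.
if out.shape[1] > (template_end + 3):
    if tok_pairs[template_end + 1][0] == 872 and tok_pairs[template_end + 2][0] == 198:
        template_end += 3

print(out[:, template_end:].shape)  # torch.Size([1, 3, 8]): only the 3 prompt tokens remain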