mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 17:37:39 +08:00
125 lines
5.7 KiB
Python
125 lines
5.7 KiB
Python
import torch
|
|
|
|
from comfy import sd1_clip
|
|
from .lumina2 import Gemma2BTokenizer, LuminaModel
|
|
import comfy.text_encoders.llama
|
|
|
|
|
|
class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
    """Gemma-2-2b-it text encoder for PixelDiT.

    Uses the FINAL hidden state (layer='last').

    Args:
        device: Device forwarded unchanged to ``SDClipModel``.
        layer: Hidden-state selector forwarded to ``SDClipModel``; defaults to ``"last"``.
        layer_idx: Forwarded unchanged to ``SDClipModel``.
        dtype: Forwarded unchanged to ``SDClipModel``.
        attention_mask: Drives both ``enable_attention_masks`` and
            ``return_attention_masks`` of ``SDClipModel``.
        model_options: Extra options forwarded to ``SDClipModel``. ``None`` (the
            default) is treated as an empty dict.
    """

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None):
        # Build a fresh dict per call instead of sharing one mutable default
        # object across every instance.
        if model_options is None:
            model_options = {}
        super().__init__(
            device=device, layer=layer, layer_idx=layer_idx,
            textmodel_json_config={}, dtype=dtype,
            # Token ids configured as the start and pad tokens for this encoder.
            special_tokens={"start": 2, "pad": 0},
            layer_norm_hidden_state=False,
            model_class=comfy.text_encoders.llama.Gemma2_2B,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
|
|
|
|
|
|
# Instruction preamble ("chi_prompt") prepended to every non-empty user prompt
# before tokenization. It is assembled line by line and joined with "\n"; the
# final line intentionally has no trailing newline so the user's text follows
# "User Prompt: " directly. Its token count also sizes the padded sequence, so
# the text must not drift (presumably it mirrors upstream PixelDiT verbatim).
_PIXELDIT_CHI_PROMPT = "\n".join((
    'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
    "suitable for image generation. Evaluate the level of detail in the user prompt:",
    "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
    "and spatial relationships to create vivid and concrete scenes.",
    "- If the prompt is already detailed, refine and enhance the existing details slightly without "
    "overcomplicating.",
    "Here are examples of how to transform or refine prompts:",
    "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
    "sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
    "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
    "glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
    "passing by towering glass skyscrapers.",
    "Please generate only the enhanced description for the prompt below and avoid including any "
    "additional commentary or evaluations:",
    "User Prompt: ",
))

# Number of conditioning positions the diffusion model consumes per prompt.
_PIXELDIT_MAX_LENGTH = 300

# Prompts that already begin with this text are assumed to carry their own
# preamble and are tokenized verbatim instead of getting the chi_prompt prepended.
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
|
|
|
|
|
|
def _build_padded_tokens(combined_text: str, spiece_tokenizer, pad_id: int, chi_token_count: int,
                         max_length=None):
    """Tokenize ``combined_text`` and right-pad/truncate it to upstream's ``max_length_all``.

    The target length is ``chi_token_count + max_length - 2``, matching upstream's
    ``max_length_all``. The encoder wrapper later keeps position 0 plus the last
    ``max_length - 1`` positions of the encoded sequence.

    Args:
        combined_text: Preamble plus user prompt to tokenize.
        spiece_tokenizer: Callable returning a mapping with an ``"input_ids"`` sequence.
        pad_id: Token id appended when the sequence is shorter than the target.
        chi_token_count: Token count of the preamble.
        max_length: Conditioning length; ``None`` means ``_PIXELDIT_MAX_LENGTH`` (300).

    Returns:
        A new list of exactly ``chi_token_count + max_length - 2`` token ids.

    Raises:
        ValueError: If the computed target length is not positive.
    """
    if max_length is None:
        max_length = _PIXELDIT_MAX_LENGTH
    max_length_all = chi_token_count + max_length - 2
    if max_length_all <= 0:
        raise ValueError(
            f"padded token length must be positive, got {max_length_all} "
            f"(chi_token_count={chi_token_count}, max_length={max_length})"
        )
    # Copy into a fresh list so callers never alias the tokenizer's own output and
    # tuple-returning tokenizers work too; then truncate on the right.
    ids = list(spiece_tokenizer(combined_text)["input_ids"])[:max_length_all]
    # Right-pad; this is a no-op when the sequence already has the target length.
    ids.extend([pad_id] * (max_length_all - len(ids)))
    return ids
|
|
|
|
|
|
class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
    """Gemma-2-2b-it tokenizer that prepends PixelDiT's chi_prompt.

    Empty text -> BOS + pad to 300. Text already starting with the chi_prompt
    preamble is tokenized verbatim (override mirrors QwenImageTokenizer's
    `<|im_start|>` detection). Else chi_prompt is prepended.

    All tokens are emitted with weight 1.0; prompt-weighting syntax is not parsed.
    """

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        if tokenizer_data is None:
            tokenizer_data = {}
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
                         name="gemma2_2b", tokenizer=Gemma2BTokenizer)

    def _chi_prompt_token_count(self):
        """Return the token count of `_PIXELDIT_CHI_PROMPT`, computed once per instance."""
        # The preamble and the tokenizer are both fixed per instance, so the count
        # is invariant; getattr keeps this safe if __init__ was bypassed.
        count = getattr(self, "_chi_token_count", None)
        if count is None:
            count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"])
            self._chi_token_count = count
        return count

    def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
        """Tokenize ``text`` into ``{"gemma2_2b": [[(token_id, 1.0), ...]]}``.

        ``return_word_ids`` and ``**kwargs`` are accepted for interface
        compatibility but are not used.
        """
        spiece_tokenizer = self.gemma2_2b.tokenizer
        pad_id = self.gemma2_2b.pad_token

        if not (isinstance(text, str) and text.strip()):
            # Non-string, empty or whitespace-only prompt: tokenize "" and right-pad
            # to _PIXELDIT_MAX_LENGTH without any preamble.
            ids = list(spiece_tokenizer("")["input_ids"])
            ids.extend([pad_id] * (_PIXELDIT_MAX_LENGTH - len(ids)))
        else:
            if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX):
                combined = text  # caller already supplied the preamble
            else:
                combined = _PIXELDIT_CHI_PROMPT + text
            # The padded length is always sized from the default preamble's token count.
            ids = _build_padded_tokens(combined, spiece_tokenizer, pad_id,
                                       self._chi_prompt_token_count())

        return {"gemma2_2b": [[(t, 1.0) for t in ids]]}

    def untokenize(self, token_weight_pair):
        """Delegate to the wrapped gemma2_2b tokenizer."""
        return self.gemma2_2b.untokenize(token_weight_pair)

    def state_dict(self):
        """Return the wrapped gemma2_2b tokenizer's state dict."""
        return self.gemma2_2b.state_dict()
|
|
|
|
|
|
class PixelDiTGemma2TE(LuminaModel):
    """Text encoder wrapper for PixelDiT.

    Overrides `encode_token_weights` to perform PixelDiT's `select_index` step:
    encode the full padded sequence (up to ~chi_prompt_tokens + 298), then
    return `[BOS_emb] + last_299_embs` as the 300-position conditioning that
    matches the diffusion model's learned `y_pos_embedding` positions.
    """

    def __init__(self, device="cpu", dtype=None, model_options=None):
        # Fresh dict per call instead of a shared mutable default.
        if model_options is None:
            model_options = {}
        super().__init__(device=device, dtype=dtype, name="gemma2_2b",
                         clip_model=PixelDiTGemma2_2BModel, model_options=model_options)

    @staticmethod
    def _select_index(x, dim):
        """Return position 0 of ``x`` along ``dim`` followed by its last
        ``_PIXELDIT_MAX_LENGTH - 1`` positions (assumes the length exceeds that)."""
        keep = _PIXELDIT_MAX_LENGTH - 1
        head = x.narrow(dim, 0, 1)
        tail = x.narrow(dim, x.shape[dim] - keep, keep)
        return torch.cat([head, tail], dim=dim)

    def encode_token_weights(self, token_weight_pairs):
        """Encode tokens, then reduce the sequence to ``_PIXELDIT_MAX_LENGTH`` positions.

        Returns ``(cond, pooled)`` or ``(cond, pooled, extra)``, mirroring the
        parent's result. Sequences no longer than ``_PIXELDIT_MAX_LENGTH`` are
        returned unchanged. When ``extra["attention_mask"]`` is present and its
        last dimension matches the sequence length, the mask is reduced the same
        way (a 1-D mask is first given a leading batch dimension). Masks of any
        other length are passed through untouched.
        """
        result = super().encode_token_weights(token_weight_pairs)
        cond, pooled = result[0], result[1]
        extra = result[2] if len(result) > 2 else None

        seq_len = cond.shape[1]
        if seq_len > _PIXELDIT_MAX_LENGTH:
            cond = self._select_index(cond, 1)
            if extra is not None and "attention_mask" in extra:
                am = extra["attention_mask"]
                if am.dim() == 1:
                    am = am.unsqueeze(0)
                if am.shape[-1] == seq_len:
                    # Copy the dict so the parent's `extra` object is not mutated.
                    extra = {**extra, "attention_mask": self._select_index(am, -1)}

        if extra is not None:
            return cond, pooled, extra
        return cond, pooled
|