ComfyUI/comfy/text_encoders/joyimage.py
huangfeice 5260e18cdf Add JoyImageEdit native model support
JoyImageEdit is an image-edit diffusion transformer from JD (jd-opensource),
released under the Apache 2.0 license. This adds native ComfyUI support so it
loads and runs like other edit models (load checkpoint -> TextEncode +
ReferenceLatent -> KSampler -> VAEDecode), with no diffusers dependency.

Architecture:
- Transformer (comfy/ldm/joyimage/model.py): dual-stream (img/txt) DiT with a
  Conv3d patch embed (patch_size [1,2,2]), Wan-style learnable modulation,
  and 3D RoPE (rope_dim_list [16,56,56]). All attention goes through
  comfy.ldm.modules.attention.optimized_attention.
- Text encoder (comfy/text_encoders/{qwen3_vl,joyimage}.py): a reusable
  Qwen3-VL multimodal stack (vision tower + LM) in qwen3_vl.py, plus a thin
  JoyImage-specific layer (prompt templates, drop_idx, tokenizer, te() factory)
  in joyimage.py that depends on it. text_dim 4096.
- VAE: reuses the existing Wan 2.1 latent format (AutoencoderKLWan), no new
  latent format.
- Edit conditioning: reuses the reference_latents mechanism. Reference and
  noise latents are stacked on a new n-slot dimension and rotated at the model
  boundary (model_base.JoyImage), so the transformer stays 5D-in/5D-out.
  Guidance-rescale is built into the CFG path.

Model wiring:
- model_base.JoyImage uses ModelType.FLOW with sampling_settings
  multiplier=1000 (the time embedding is trained on t in [0,1000]) and
  shift=1.5; FLOW's linear time_snr_shift matches the diffusers
  FlowMatchEuler sigma schedule.
- model_detection sniffs the transformer state-dict (double_blocks.*,
  condition_embedder.*, 5D img_in Conv3d) to route image_model="joyimage".
- supported_models.JoyImage and the CLIPLoader "joyimage" type register it.

The user-facing node TextEncodeJoyImageEdit (comfy_extras/nodes_joyimage.py)
bucket-resizes the input image to the nearest 1024-base bucket, encodes the
prompt together with the image, and emits both the conditioning and the
bucketed image, so the same pixels feed VAEEncode and the negative encode
(JoyImage requires the noise and reference latents to share spatial dims).
2026-06-17 18:53:36 +08:00

186 lines
8.4 KiB
Python

"""JoyImageEdit text encoder: Qwen3-VL multimodal stack feeding the JoyImageEdit DiT.
Plugs the generic Qwen3-VL stack from `comfy.text_encoders.qwen3_vl` into the
`SDClipModel` / `SD1ClipModel` contract, adding only the JoyImage-specific
templates, drop_idx, tokenizer wrapper, and `te()` factory.
"""
import os
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
from comfy.text_encoders.qwen3_vl import Qwen3VLBase
# Prompt templates for the text-only and image-conditioned modes.
#
# Both templates share one system block, defined once below so the two cannot
# drift apart: JOYIMAGE_DROP_IDX assumes the tokenized prefix is identical in
# either mode. The image-conditioned template wraps the user text with a single
# `<|vision_start|><|image_pad|><|vision_end|>` block; this encoder supports one
# user turn per call.
#
# NOTE: the system line contains `" \\n"`, i.e. a space, a literal backslash
# and an "n" -- not a second newline. It is kept byte-for-byte; presumably it
# mirrors the reference template the model was trained with (TODO confirm
# before "fixing" it, since editing it may change the token count that
# JOYIMAGE_DROP_IDX relies on).
_JOYIMAGE_SYSTEM_BLOCK = (
    "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
)
JOYIMAGE_TEMPLATE_TEXT = (
    _JOYIMAGE_SYSTEM_BLOCK
    + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
)
JOYIMAGE_TEMPLATE_IMAGE = (
    _JOYIMAGE_SYSTEM_BLOCK
    + "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# Number of leading tokens stripped from the encoded output by
# JoyImageTEModel.encode_token_weights: tokens 0..33 of either formatted
# template, i.e. the system block plus the `<|im_start|>` that opens the user
# turn, so the kept tail begins at the `user` token.
JOYIMAGE_DROP_IDX = 34

# Special-token ids from the JoyImage Qwen3-VL tokenizer (vocab is shared
# with Qwen2.5 / Qwen3 -- vocab_size 151936).
IMAGE_PAD_TOKEN = 151655  # `<|image_pad|>` placeholder swapped for image embeddings
PAD_TOKEN = 151643  # pad id handed to both the tokenizer and the clip model
class Qwen3VL8B_JoyImage(Qwen3VLBase):
    """Qwen3-VL 8B multimodal stack, specialised in name only for JoyImageEdit.

    The subclass adds no behaviour: its constructor forwards every argument
    unchanged to `Qwen3VLBase`. It exists so the JoyImage encoder has its own
    ``model_class`` to hand to `SDClipModel`.

    Per the author, the JoyImage checkpoint uses the standard Qwen3-VL 8B text
    dims (4096 / 36L / 32H / 8 kv / silu / qkv_bias=False, q/k_norm=gemma3)
    plus interleaved 3D MRoPE with rope_dims=[24, 20, 20] and rope_theta=5e6,
    all of which are defaults of `Qwen3VLConfig`. The vision tower uses the
    `Qwen3VLVisionConfig` defaults (1152/4304/4096/16H, 27 blocks,
    patch_size=16, deepstack_visual_indexes=[8, 16, 24]).
    """

    def __init__(self, config_dict, dtype, device, operations):
        # Straight positional pass-through to the generic Qwen3-VL stack.
        super(Qwen3VL8B_JoyImage, self).__init__(config_dict, dtype, device, operations)
class _JoyImageBaseTokenizer(sd1_clip.SDTokenizer):
    """Low-level Qwen BPE tokenizer underneath `JoyImageTokenizer`.

    Reuses the ``qwen25_tokenizer`` artefacts that ship next to this module.
    Per the author, JoyImage shares the vocab/merges of Qwen2.5/Qwen3
    (vocab_size 151936), and that vocab already contains the image-pad,
    vision-start and vision-end special tokens.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        here = os.path.dirname(os.path.realpath(__file__))
        vocab_dir = os.path.join(here, "qwen25_tokenizer")
        settings = dict(
            pad_with_end=False,
            embedding_directory=embedding_directory,
            embedding_size=4096,
            embedding_key="qwen3vl_8b",
            tokenizer_class=Qwen2Tokenizer,
            # No start/end tokens are added and sequences are not padded out
            # to a fixed length.
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,  # very large cap on sequence length
            min_length=1,
            pad_token=PAD_TOKEN,
            tokenizer_data=tokenizer_data,
        )
        super().__init__(vocab_dir, **settings)
class JoyImageTokenizer(sd1_clip.SD1Tokenizer):
    """JoyImageEdit tokenizer.

    ``tokenize_with_weights(text, images=[...])`` chooses the prompt template:

    * ``text`` that already starts with ``<|im_start|>`` is used verbatim;
    * otherwise an explicit ``llama_template`` wins;
    * otherwise the image-conditioned template is used when one or more images
      are passed, and the text-only template when none are.

    Each ``<|image_pad|>`` token in the formatted prompt is replaced, in order,
    with an ``{"type": "image", ...}`` embedding marker so
    `SDClipModel.process_tokens` routes the image through
    `Qwen3VL8B_JoyImage.preprocess_embed`. The leading ``drop_idx=34``
    template tokens are stripped downstream by
    `JoyImageTEModel.encode_token_weights`.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
            name="qwen3vl_8b", tokenizer=_JoyImageBaseTokenizer,
        )
        self.llama_template = JOYIMAGE_TEMPLATE_TEXT
        self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,
                              images=None, **kwargs):
        """Tokenize ``text`` and splice ``images`` into its ``<|image_pad|>`` slots.

        Args:
            text: the user prompt, or a fully pre-formatted chat prompt that
                starts with ``<|im_start|>``.
            return_word_ids: forwarded to ``SD1Tokenizer.tokenize_with_weights``.
            llama_template: optional ``str.format`` template with one ``{}``
                slot; overrides the built-in templates.
            images: images to place, one per ``<|image_pad|>`` placeholder.
                ``None`` (the default) means no images.
            **kwargs: forwarded to ``SD1Tokenizer.tokenize_with_weights``.

        Returns:
            The token dict produced by ``SD1Tokenizer.tokenize_with_weights``,
            with image placeholders replaced by embedding markers.

        Raises:
            ValueError: if the number of ``<|image_pad|>`` placeholders in the
                formatted prompt differs from the number of images supplied.
        """
        # `None` instead of a shared mutable `[]` default.
        if images is None:
            images = []

        if text.startswith("<|im_start|>"):
            llama_text = text
        elif llama_template is not None:
            llama_text = llama_template.format(text)
        elif len(images) > 0:
            llama_text = self.llama_template_images.format(text)
        else:
            llama_text = self.llama_template.format(text)

        tokens = super().tokenize_with_weights(
            llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
        )
        key_name = next(iter(tokens))

        # Count every placeholder, independently of how many images were
        # available to fill it. The mismatch check below then reports the
        # true placeholder count and catches both "too many images" and
        # "too many placeholders" (a leftover raw pad token would otherwise
        # reach the model with no image behind it).
        placeholder_count = 0
        for row in tokens[key_name]:
            for i, tok in enumerate(row):
                if tok[0] != IMAGE_PAD_TOKEN:
                    continue
                if placeholder_count < len(images):
                    row[i] = ({"type": "image", "data": images[placeholder_count],
                               "original_type": "image"},) + tok[1:]
                placeholder_count += 1

        if placeholder_count != len(images):
            raise ValueError(
                f"JoyImageTokenizer: prompt had {placeholder_count} <|image_pad|> placeholders "
                f"but {len(images)} image(s) were supplied. Either pre-format the prompt "
                f"with `<|vision_start|><|image_pad|><|vision_end|>` per image or pass an "
                f"image-free prompt."
            )
        return tokens
class _JoyImageClipModel(sd1_clip.SDClipModel):
    """Qwen3-VL multimodal encoder wrapper for JoyImageEdit.

    The defaults ``layer="hidden", layer_idx=-1`` together with
    ``layer_norm_hidden_state=False`` select the *pre-norm* hidden state.
    `SDClipModel.forward` runs the transformer with ``intermediate_output=-1``
    (resolved to ``num_layers - 1``) and ``final_layer_norm_intermediate=False``.
    The captured tensor is therefore the output of the last decoder layer
    **before** the final norm, not the post-norm ``last_hidden_state``.

    **Do NOT 'simplify' this to layer="last" / final_layer_norm_intermediate=True.**
    That yields the post-norm output, which differs by roughly 10x in scale
    (std approx 21 vs 2) and produces broken DiT outputs.
    """

    def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None,
                 attention_mask=True, model_options={}):
        # A single flag drives both building the mask and returning it.
        mask_flags = {
            "enable_attention_masks": attention_mask,
            "return_attention_masks": attention_mask,
        }
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            # Empty config: presumably the Qwen3-VL defaults apply (TODO confirm
            # against Qwen3VLBase).
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"pad": PAD_TOKEN},
            # Keep the hidden state un-normed -- see the class docstring.
            layer_norm_hidden_state=False,
            model_class=Qwen3VL8B_JoyImage,
            model_options=model_options,
            **mask_flags,
        )
class JoyImageTEModel(sd1_clip.SD1ClipModel):
    """JoyImageEdit text-encoder model: one Qwen3-VL stream named ``qwen3vl_8b``.

    Wraps `_JoyImageClipModel` and trims the leading ``JOYIMAGE_DROP_IDX``
    template tokens from what it produces.
    """

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="qwen3vl_8b",
            clip_model=_JoyImageClipModel,
            model_options=model_options,
        )

    def encode_token_weights(self, token_weight_pairs):
        """Encode via the parent, then drop the system-prompt prefix.

        Returns the parent's ``(out, pooled, extra)`` triple, with the first
        ``JOYIMAGE_DROP_IDX`` positions along dim 1 removed from ``out``. When
        present, the same positions are removed from ``extra["attention_mask"]``.

        Raises:
            ValueError: if ``out`` has no positions left after the prefix.
        """
        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
        seq_len = out.shape[1]
        if seq_len <= JOYIMAGE_DROP_IDX:
            raise ValueError(
                f"JoyImageTEModel: encoded sequence length {seq_len} is shorter "
                f"than drop_idx={JOYIMAGE_DROP_IDX}; the prompt did not include the "
                f"template prefix."
            )

        # Same tail window for the embeddings and their mask so they stay aligned.
        tail = slice(JOYIMAGE_DROP_IDX, None)
        out = out[:, tail]
        if "attention_mask" in extra:
            mask = extra["attention_mask"]
            extra["attention_mask"] = mask[:, tail]
        return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a `JoyImageTEModel` subclass with LLM dtype/quantization baked in.

    Args:
        dtype_llama: when not None, replaces whatever ``dtype`` the loader
            passes to the constructor.
        llama_quantization_metadata: when not None, stored under
            ``"quantization_metadata"`` in a copy of ``model_options``.

    Returns:
        A class whose constructor has the same signature as `JoyImageTEModel`.
    """
    class JoyImageTEModel_(JoyImageTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_quantization_metadata is not None:
                # Work on a copy so neither the caller's dict nor the shared
                # default is mutated.
                model_options = model_options.copy()
                model_options["quantization_metadata"] = llama_quantization_metadata
            effective_dtype = dtype if dtype_llama is None else dtype_llama
            super().__init__(device=device, dtype=effective_dtype, model_options=model_options)

    return JoyImageTEModel_