ComfyUI/comfy/text_encoders/joyimage.py
huangfeice e29384be0d Add JoyImageEditPlus multi-image edit support (unify onto Plus-style forward)
JoyImageEditPlus is the multi-image (1-6 reference images) variant of
JoyImageEdit, trained from the same base. Its diffusers transformer shares
byte-identical weight structure with the single-image variant (894 keys, no
key renames) but injects references differently: instead of the single-image
slot-stack (stack refs + noise into a 6D tensor and rotate on the frame dim,
which forces all items to share resolution), each reference is independently
patchified and concatenated on the sequence dim with per-image temporal-offset
3D RoPE, allowing references at different resolutions.

Since the single-image port is not yet upstream, this unifies both variants
onto the Plus-style forward rather than keeping two paths; single-image is now
the ref=1 special case. Verified numerically: at ref=1 with equal resolution
the new path's RoPE is bit-identical to the old slot-stack layout, and the
transformer output matches the diffusers Plus reference (fp32, incl. the
different-resolution case).

ComfyUI runs cond/uncond in one forward with a shared reference configuration,
so the diffusers Plus batched RoPE, padding attention_mask, and dedicated
attention processor are unnecessary here: the unified forward reuses the
existing unbatched _apply_rotary_emb and JoyImageAttention. Confirmed
equivalent to the diffusers batched+mask path for a single sample.

- comfy/ldm/joyimage/model.py: forward takes ref_latents and builds
  components=[target, ref0, ...]; per-component patchify + temporal-offset
  RoPE; output keeps only the target segment. Old single-grid RoPE removed.
- comfy/model_base.py: JoyImage drops the slot-stack / frame-rotation /
  shape-equality path in _apply_model, passing ref_latents straight to the
  transformer. Guidance-rescale and the reference_latents requirement are kept.
- comfy/text_encoders/joyimage.py: the image template emits one vision block
  per reference (N = image count); N=1 is byte-for-byte the old template.
- comfy_extras/nodes_joyimage.py: add TextEncodeJoyImageEditPlus with optional
  image1..image6 inputs, each bucket-resized and VAE-encoded into the
  reference_latents list.

Detection, supported_models, and sd.py need no changes: the identical weight
structure routes both variants through image_model="joyimage".
2026-07-01 18:36:43 +08:00

281 lines
12 KiB
Python

"""JoyImageEdit text encoder: a stock Qwen3-VL-8B multimodal stack feeding the
JoyImageEdit DiT, built on `comfy.text_encoders.qwen3vl` with the
JoyImage-specific prompt templates, system-prompt strip, image preprocessing,
and conditioning-path multimodal handling.
"""
import math
from typing import List, Optional
import torch
import torch.nn.functional as F
from comfy import sd1_clip
from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer
# System prompt shared verbatim by both templates. The `\\n` is a literal
# backslash followed by `n` (not a newline) and must stay exactly as written:
# JOYIMAGE_DROP_IDX below counts the tokens of this prefix.
_JOYIMAGE_SYSTEM_PREFIX = (
    "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
)

# Text-only mode: the single positional `{}` slot receives the user prompt.
JOYIMAGE_TEMPLATE_TEXT = (
    _JOYIMAGE_SYSTEM_PREFIX
    + "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
)

# Image-conditioned mode: `{vision}` takes N back-to-back vision blocks (one per
# reference image, no separator) and `{prompt}` takes the user prompt.
JOYIMAGE_TEMPLATE_IMAGE = (
    _JOYIMAGE_SYSTEM_PREFIX
    + "<|im_start|>user\n{vision}{prompt}<|im_end|>\n<|im_start|>assistant\n"
)

# One reference image's placeholder; repeat it N times for N images.
JOYIMAGE_VISION_BLOCK = "<|vision_start|><|image_pad|><|vision_end|>"

# How many leading tokens (the system prompt plus the user turn's opening
# `<|im_start|>`) JoyImageTEModel.encode_token_weights discards, so the kept
# sequence starts at the `user` token.
JOYIMAGE_DROP_IDX = 34

# Special-token ids from the Qwen2.5 / Qwen3 vocabulary (vocab_size 151936).
IMAGE_PAD_TOKEN = 151655
PAD_TOKEN = 151643
# ---------------------------------------------------------------------------
# Image preprocessing
# ---------------------------------------------------------------------------
def process_qwen3vl_image(
    image: torch.Tensor,
    min_pixels: int = 65536,
    max_pixels: int = 16777216,
    patch_size: int = 16,
    temporal_patch_size: int = 2,
    merge_size: int = 2,
    image_mean: Optional[List[float]] = None,
    image_std: Optional[List[float]] = None,
):
    """Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1].

    The spatial size is snapped to a multiple of ``patch_size * merge_size``.
    If the snapped pixel count exceeds ``max_pixels`` (or falls below
    ``min_pixels``), it is rescaled down (or up) while keeping the aspect ratio.
    Resizing uses bicubic interpolation followed by ``clamp(0, 1)``. Each
    channel is then normalized as ``(x - mean[c]) / std[c]``.

    Args:
        image: ``(H, W, C)`` or ``(1, H, W, C)`` tensor with values in [0, 1].
            Only the first ``len(image_mean)`` channels are used, so e.g. an
            alpha channel is ignored.
        min_pixels / max_pixels: bounds on the resized ``H * W``.
        patch_size: spatial patch edge of the vision tower.
        temporal_patch_size: frames per temporal patch; the single frame is
            repeated to fill it.
        merge_size: spatial merge factor; patches are grouped in
            ``merge_size x merge_size`` windows.
        image_mean / image_std: per-channel normalization (default 0.5 each).

    Returns:
        ``(flatten_patches, grid_thw)`` ready for the Qwen3-VL vision tower.
        ``flatten_patches`` has shape
        ``(grid_h * grid_w, C * temporal_patch_size * patch_size ** 2)``.
        ``grid_thw`` is the long tensor ``[[1, grid_h, grid_w]]``.

    Raises:
        ValueError: if more than one image is passed, if ``image_mean`` and
            ``image_std`` differ in length, or if the image has fewer channels
            than ``image_mean`` has entries.
    """
    if image_mean is None:
        image_mean = [0.5, 0.5, 0.5]
    if image_std is None:
        image_std = [0.5, 0.5, 0.5]
    if len(image_mean) != len(image_std):
        raise ValueError(
            f"process_qwen3vl_image: image_mean has {len(image_mean)} entries but "
            f"image_std has {len(image_std)}."
        )
    if image.dim() == 3:
        image = image.unsqueeze(0)
    batch, height, width, channels = image.shape
    if batch != 1:
        raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.")
    num_channels = len(image_mean)
    if channels < num_channels:
        raise ValueError(
            f"process_qwen3vl_image expects at least {num_channels} channels, got {channels}."
        )
    device = image.device
    # Keep only the channels we have normalization stats for (drops e.g. alpha),
    # so the flattened patch width matches what the vision tower is built for.
    img = image[..., :num_channels].permute(0, 3, 1, 2)  # (1, C, H, W)

    factor = patch_size * merge_size
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    resized = F.interpolate(
        img, size=(h_bar, w_bar), mode="bicubic", align_corners=False,
    )[0].clamp(0.0, 1.0)  # (C, h_bar, w_bar)
    mean = torch.tensor(image_mean, device=device, dtype=resized.dtype).view(-1, 1, 1)
    std = torch.tensor(image_std, device=device, dtype=resized.dtype).view(-1, 1, 1)
    normalized = (resized - mean) / std

    grid_h = h_bar // patch_size
    grid_w = w_bar // patch_size
    grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long)
    # Single-frame inputs are duplicated along T to fill the 2-frame temporal
    # patch kernel; matches Qwen2VLImageProcessorFast for static images.
    pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
    grid_t = 1
    patches = pixel_values.reshape(
        grid_t, temporal_patch_size, num_channels,
        grid_h // merge_size, merge_size, patch_size,
        grid_w // merge_size, merge_size, patch_size,
    )
    # -> (t, h_blocks, w_blocks, merge_h, merge_w, C, T, patch_h, patch_w).
    # Each merge window's patches become consecutive rows, and each row is
    # flattened channel-major.
    patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
    flatten_patches = patches.reshape(
        grid_t * grid_h * grid_w,
        num_channels * temporal_patch_size * patch_size * patch_size,
    )
    return flatten_patches, grid_thw
class Qwen3VL8B_JoyImage(Qwen3VL):
    """Qwen3-VL-8B encoder used for JoyImage conditioning.

    Stock `qwen3vl_8b` config (text dims 4096 / 36L / 32H / 8 kv; interleaved
    3D MRoPE rope_dims=[24,20,20], rope_theta=5e6; vision 1152/4304, depth 27,
    patch_size 16, deepstack_visual_indexes=[8,16,24]).

    Two hooks are customised: ``preprocess_embed`` feeds images through
    ``process_qwen3vl_image`` (bicubic + clamp), and ``forward`` builds the
    multimodal decoder inputs whenever embeds are present.
    """

    model_type = "qwen3vl_8b"

    def preprocess_embed(self, embed, device):
        """Encode one embedding marker emitted by the tokenizer.

        For ``embed["type"] == "image"``, returns
        ``(merged, {"grid": grid_thw, "deepstack": deepstack})`` from the vision
        tower. Every other marker type yields ``(None, None)``.
        """
        if embed["type"] != "image":
            return None, None
        pixels, grid_thw = process_qwen3vl_image(
            embed["data"],
            patch_size=16,
            image_mean=[0.5, 0.5, 0.5],
            image_std=[0.5, 0.5, 0.5],
        )
        pixels = pixels.to(device, dtype=torch.float32)
        merged, deepstack = self.visual(pixels, grid_thw)
        return merged, {"grid": grid_thw, "deepstack": deepstack}

    def forward(self, x, attention_mask=None, embeds=None, num_tokens=None,
                intermediate_output=None, final_layer_norm_intermediate=True,
                dtype=None, embeds_info=()):
        """Run the decoder, supplying multimodal inputs when ``embeds`` is given.

        ``build_image_inputs`` provides the 3D MRoPE position ids for the
        image-token block, the visual-position mask, and the deepstack features
        to inject. Without embeds, all three are passed as ``None``.
        """
        multimodal_inputs = (None, None, None)
        if embeds is not None:
            multimodal_inputs = self.build_image_inputs(embeds, embeds_info)
        position_ids, visual_pos_masks, deepstack = multimodal_inputs
        return self.model(
            x,
            attention_mask=attention_mask,
            embeds=embeds,
            num_tokens=num_tokens,
            intermediate_output=intermediate_output,
            final_layer_norm_intermediate=final_layer_norm_intermediate,
            dtype=dtype,
            position_ids=position_ids,
            deepstack_embeds=deepstack,
            visual_pos_masks=visual_pos_masks,
        )
class JoyImageTokenizer(Qwen3VLTokenizer):
    """JoyImageEdit tokenizer.

    ``tokenize_with_weights(text, images=[...])`` selects the image-conditioned
    template when one or more image tensors are passed, emitting one
    ``<|vision_start|><|image_pad|><|vision_end|>`` block per image (N blocks
    for N reference images). Otherwise it uses the text-only template.

    Each ``<|image_pad|>`` token in the formatted prompt is replaced with an
    embedding marker, so `SDClipModel.process_tokens` routes each image through
    `Qwen3VL8B_JoyImage.preprocess_embed`. The ``drop_idx=34`` leading template
    tokens are stripped downstream by `JoyImageTEModel.encode_token_weights`.
    No ``<think>`` block is appended.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
            model_type="qwen3vl_8b",
        )
        self.llama_template = JOYIMAGE_TEMPLATE_TEXT
        self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,
                              images=None, **kwargs):
        """Render the prompt template, tokenize it, and splice in image markers.

        Template selection, in priority order:
          1. ``text`` that already starts with ``<|im_start|>`` is used verbatim;
          2. an explicit ``llama_template`` is filled with ``str.format(text)``;
          3. if ``images`` is non-empty, the image template with one vision
             block per image;
          4. otherwise the text-only template.

        The i-th ``<|image_pad|>`` token is replaced by an embedding marker for
        ``images[i]``.

        Raises:
            ValueError: if the number of ``<|image_pad|>`` placeholders in the
                tokenized prompt differs from ``len(images)``. This covers both
                too few placeholders and surplus ones, which would otherwise
                reach the encoder as raw pad ids.
        """
        if images is None:
            images = []
        if text.startswith("<|im_start|>"):
            llama_text = text
        elif llama_template is not None:
            llama_text = llama_template.format(text)
        elif len(images) > 0:
            # One vision block per reference image.
            vision = JOYIMAGE_VISION_BLOCK * len(images)
            llama_text = self.llama_template_images.format(vision=vision, prompt=text)
        else:
            llama_text = self.llama_template.format(text)
        # Tokenize the already-rendered template via the grandparent
        # (SD1Tokenizer); calling `super()` would re-apply the Qwen3VL template.
        tokens = sd1_clip.SD1Tokenizer.tokenize_with_weights(
            self, llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
        )
        key_name = next(iter(tokens))
        # Count every placeholder, independently of whether an image was
        # available for it, so the mismatch check below is exact.
        placeholder_count = 0
        for row in tokens[key_name]:
            for i, tok in enumerate(row):
                if tok[0] != IMAGE_PAD_TOKEN:
                    continue
                if placeholder_count < len(images):
                    row[i] = ({"type": "image", "data": images[placeholder_count],
                               "original_type": "image"},) + tok[1:]
                placeholder_count += 1
        if placeholder_count != len(images):
            raise ValueError(
                f"JoyImageTokenizer: prompt had {placeholder_count} <|image_pad|> placeholders "
                f"but {len(images)} image(s) were supplied. Either pre-format the prompt "
                f"with `<|vision_start|><|image_pad|><|vision_end|>` per image or pass an "
                f"image-free prompt."
            )
        return tokens
class _JoyImageClipModel(sd1_clip.SDClipModel):
    """SDClipModel wrapper around the JoyImage Qwen3-VL multimodal encoder.

    Conditioning is taken from the **pre-final-norm** output of the last
    decoder layer (``layer="hidden", layer_idx=-1, layer_norm_hidden_state=False``).
    The post-norm ``last_hidden_state`` is roughly 10x larger in scale and
    breaks the DiT outputs, so these flags must stay as they are.

    ``attention_mask`` turns on both attention-mask use and attention-mask
    return in the underlying SDClipModel.
    """

    def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None,
                 attention_mask=True, model_options={}):
        clip_kwargs = dict(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"pad": PAD_TOKEN},
            layer_norm_hidden_state=False,  # keep the pre-final-norm hidden state
            model_class=Qwen3VL8B_JoyImage,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
        super().__init__(**clip_kwargs)
class JoyImageTEModel(sd1_clip.SD1ClipModel):
    """JoyImageEdit text encoder: one ``qwen3vl_8b`` clip plus template-prefix removal."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="qwen3vl_8b",
            clip_model=_JoyImageClipModel,
            model_options=model_options,
        )

    def encode_token_weights(self, token_weight_pairs):
        """Encode via the parent, then drop the first JOYIMAGE_DROP_IDX positions.

        The same prefix is removed from the embedding sequence (dim 1) and, when
        present, from ``extra["attention_mask"]``.

        Raises:
            ValueError: if the encoded sequence is not longer than
                JOYIMAGE_DROP_IDX, i.e. nothing would remain after stripping.
        """
        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
        seq_len = out.shape[1]
        if seq_len <= JOYIMAGE_DROP_IDX:
            raise ValueError(
                f"JoyImageTEModel: encoded sequence length {seq_len} is shorter "
                f"than drop_idx={JOYIMAGE_DROP_IDX}; the prompt did not include the "
                f"template prefix."
            )
        kept = slice(JOYIMAGE_DROP_IDX, None)
        out = out[:, kept]
        if "attention_mask" in extra:
            extra["attention_mask"] = extra["attention_mask"][:, kept]
        return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a JoyImageTEModel subclass with dtype / quantization overrides baked in.

    Args:
        dtype_llama: if not None, replaces the ``dtype`` given at construction.
        llama_quantization_metadata: if not None, stored under
            ``model_options["quantization_metadata"]`` on a copy of the
            caller's options.

    Returns:
        The configured subclass (a class, not an instance).
    """

    class JoyImageTEModel_(JoyImageTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            dtype = dtype if dtype_llama is None else dtype_llama
            if llama_quantization_metadata is not None:
                # Work on a copy so the caller's dict (possibly the shared
                # default) is never mutated.
                model_options = model_options.copy()
                model_options["quantization_metadata"] = llama_quantization_metadata
            super().__init__(device=device, dtype=dtype, model_options=model_options)

    return JoyImageTEModel_