Add JoyImageEditPlus multi-image edit support (unify onto Plus-style forward)

JoyImageEditPlus is the multi-image (1-6 reference images) variant of
JoyImageEdit, trained from the same base. Its diffusers transformer shares
byte-identical weight structure with the single-image variant (894 keys, zero
rename) but injects references differently: instead of the single-image
slot-stack (stack refs + noise into a 6D tensor and rotate on the frame dim,
which forces all items to share resolution), each reference is independently
patchified and concatenated on the sequence dim with per-image temporal-offset
3D RoPE, allowing references at different resolutions.

Since the single-image port is not yet upstream, this unifies both variants
onto the Plus-style forward rather than keeping two paths; single-image is now
the ref=1 special case. Verified numerically: at ref=1 with equal resolution
the new path's RoPE is bit-identical to the old slot-stack layout, and the
transformer output matches the diffusers Plus reference (fp32, incl. the
different-resolution case).

ComfyUI runs cond/uncond in one forward with a shared reference configuration,
so the diffusers Plus batched RoPE, padding attention_mask, and dedicated
attention processor are unnecessary here: the unified forward reuses the
existing unbatched _apply_rotary_emb and JoyImageAttention. Confirmed
equivalent to the diffusers batched+mask path for a single sample.

- comfy/ldm/joyimage/model.py: forward takes ref_latents and builds
  components=[target, ref0, ...]; per-component patchify + temporal-offset
  RoPE; output keeps only the target segment. Old single-grid RoPE removed.
- comfy/model_base.py: JoyImage drops the slot-stack / frame-rotation /
  shape-equality path in _apply_model, passing ref_latents straight to the
  transformer. Guidance-rescale and the reference_latents requirement are kept.
- comfy/text_encoders/joyimage.py: the image template emits one vision block
  per reference (N = image count); N=1 is byte-for-byte the old template.
- comfy_extras/nodes_joyimage.py: add TextEncodeJoyImageEditPlus with optional
  image1..image6 inputs, each bucket-resized and VAE-encoded into the
  reference_latents list.

Detection, supported_models, and sd.py need no changes: the identical weight
structure routes both variants through image_model="joyimage".
This commit is contained in:
huangfeice 2026-07-01 16:15:40 +08:00
parent e96bd48e2d
commit e29384be0d
4 changed files with 185 additions and 86 deletions

View File

@ -292,8 +292,6 @@ class _PixArtAlphaTextProjection(nn.Module):
class JoyImageTransformer3DModel(nn.Module):
# 6D->5D rotation and reshape happen in JoyImage.apply_model; this module is 5D-in, 5D-out.
def __init__(
self,
patch_size: list = [1, 2, 2],
@ -373,54 +371,54 @@ class JoyImageTransformer3DModel(nn.Module):
device=device,
)
def get_rotary_pos_embed(
def _get_rotary_pos_embed_for_range(
self,
vis_rope_size,
txt_rope_size: Optional[int] = None,
start: Tuple[int, int, int],
stop: Tuple[int, int, int],
device=None,
):
target_ndim = 3
vis_rope_size = list(vis_rope_size)
if len(vis_rope_size) != target_ndim:
vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
) -> Tuple[torch.Tensor, torch.Tensor]:
# 3D RoPE for the patch grid range [start, stop) over (t, h, w). Token order after
# reshape(-1) is (t, h, w), matching the img_in Conv3d flatten.
head_dim = self.hidden_size // self.num_attention_heads
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
rope_dim_list = [head_dim // 3 for _ in range(3)]
if sum(rope_dim_list) != head_dim:
raise ValueError("sum(rope_dim_list) should equal head_dim")
grid = torch.stack(
torch.meshgrid(
*[torch.linspace(0, s, s + 1, dtype=torch.float32, device=device)[:s] for s in vis_rope_size],
indexing="ij",
),
dim=0,
)
grids = [torch.arange(start[i], stop[i], dtype=torch.float32, device=device) for i in range(3)]
mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0)
vis_cos, vis_sin = [], []
cos_parts, sin_parts = [], []
for i, dim in enumerate(rope_dim_list):
pos = grid[i].reshape(-1)
pos = mesh[i].reshape(-1)
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
freqs = torch.outer(pos.float(), freqs)
vis_cos.append(freqs.cos().repeat_interleave(2, dim=1))
vis_sin.append(freqs.sin().repeat_interleave(2, dim=1))
vis_freqs = (torch.cat(vis_cos, dim=1), torch.cat(vis_sin, dim=1))
angles = torch.outer(pos, freqs)
cos_parts.append(angles.cos().repeat_interleave(2, dim=1))
sin_parts.append(angles.sin().repeat_interleave(2, dim=1))
if txt_rope_size is None:
return vis_freqs, None
return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)
grid_txt = torch.arange(txt_rope_size, device=device) + grid.view(-1).max().item() + 1
txt_cos, txt_sin = [], []
for i, dim in enumerate(rope_dim_list):
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
freqs = torch.outer(grid_txt.float(), freqs)
txt_cos.append(freqs.cos().repeat_interleave(2, dim=1))
txt_sin.append(freqs.sin().repeat_interleave(2, dim=1))
txt_freqs = (torch.cat(txt_cos, dim=1), torch.cat(txt_sin, dim=1))
return vis_freqs, txt_freqs
def get_rotary_pos_embed_for_components(
self,
component_sizes,
device=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Per-component 3D RoPE. component_sizes is a list of (t, h, w) patch grid sizes in
# sequence order [target, ref0, ref1, ...]; h/w restart at 0 for each component while t
# continues from the running offset, giving every image its own temporal position band.
cos_parts, sin_parts = [], []
t_offset = 0
for (t, h, w) in component_sizes:
cos_emb, sin_emb = self._get_rotary_pos_embed_for_range(
start=(t_offset, 0, 0),
stop=(t_offset + t, h, w),
device=device,
)
cos_parts.append(cos_emb)
sin_parts.append(sin_emb)
t_offset += t
return torch.cat(cos_parts, dim=0), torch.cat(sin_parts, dim=0)
def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
c = self.out_channels
@ -436,25 +434,57 @@ class JoyImageTransformer3DModel(nn.Module):
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
ref_latents=None,
) -> torch.Tensor:
_, _, ot, oh, ow = hidden_states.shape
tt = ot // self.patch_size[0]
th = oh // self.patch_size[1]
tw = ow // self.patch_size[2]
# The target noise latent and each reference latent are independently patchified by img_in
# (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...].
# RoPE is built per component so references may differ in resolution. Only the leading
# target segment (tt*th*tw tokens) is projected back out; reference tokens are dropped.
# A single reference is simply the len(ref_latents) == 1 case.
if hidden_states.ndim != 5:
raise ValueError(f"JoyImage transformer expects 5D (B,C,T,H,W) hidden_states; got shape {tuple(hidden_states.shape)}")
img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
_, _, ot, oh, ow = hidden_states.shape
pt, ph, pw = self.patch_size
if ot % pt != 0 or oh % ph != 0 or ow % pw != 0:
raise ValueError(
f"JoyImage: target latent spatial/temporal shape {(ot, oh, ow)} must be divisible by patch_size {tuple(self.patch_size)}"
)
tt = ot // pt
th = oh // ph
tw = ow // pw
components = [hidden_states]
if ref_latents is not None:
for r in ref_latents:
if r.ndim != 5:
raise ValueError(f"JoyImage: each reference latent must be 5D (B,C,T,H,W); got shape {tuple(r.shape)}")
components.append(r)
component_sizes = []
img_tokens = []
for comp in components:
_, _, ct, ch, cw = comp.shape
if ct % pt != 0 or ch % ph != 0 or cw % pw != 0:
raise ValueError(
f"JoyImage: component shape {(ct, ch, cw)} must be divisible by patch_size {tuple(self.patch_size)}"
)
component_sizes.append((ct // pt, ch // ph, cw // pw))
tokens = self.img_in(comp).flatten(2).transpose(1, 2) # (B, n_i, D)
img_tokens.append(tokens)
img = torch.cat(img_tokens, dim=1)
_, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
if vec.shape[-1] > self.hidden_size:
vec = vec.unflatten(1, (6, -1))
txt_seq_len = txt.shape[1]
vis_freqs, txt_freqs = self.get_rotary_pos_embed(
vis_rope_size=[tt, th, tw],
txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,
vis_cos, vis_sin = self.get_rotary_pos_embed_for_components(
component_sizes,
device=hidden_states.device,
)
vis_freqs = (vis_cos, vis_sin)
txt_freqs = None
for block in self.double_blocks:
img, txt = block(
@ -465,5 +495,7 @@ class JoyImageTransformer3DModel(nn.Module):
)
img = self.proj_out(self.norm_out(img))
target_tokens = tt * th * tw
img = img[:, :target_tokens, :]
img = self.unpatchify(img, tt, th, tw)
return img

View File

@ -2131,8 +2131,9 @@ class QwenImage(BaseModel):
return out
class JoyImage(BaseModel):
# JoyImageEdit: 6D stacking + [last, first, ...] rotation, plus hard-wired guidance rescale,
# are deliberately handled HERE (not in the transformer) so the transformer stays 5D-in / 5D-out.
# The noise latent and every reference latent are concatenated as a token sequence inside the
# transformer. A single-reference edit is just the len(ref_latents) == 1 case. The built-in CFG
# guidance rescale is installed from here.
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel)
self.memory_usage_factor_conds = ("ref_latents",)
@ -2177,8 +2178,9 @@ class JoyImage(BaseModel):
if ref_latents is None or len(ref_latents) == 0:
raise ValueError(
"JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry "
"reference_latents. Connect the same image+vae into both TextEncodeJoyImageEdit nodes. "
"Empty negative prompts still need image+vae wired."
"reference_latents. Wire the same reference image(s) and vae into both the positive and "
"negative TextEncodeJoyImageEdit / TextEncodeJoyImageEditPlus nodes. Empty negative "
"prompts still need the image(s) and vae."
)
latents = []
for lat in ref_latents:
@ -2194,8 +2196,8 @@ class JoyImage(BaseModel):
return out
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
# 6D stacking + [last, first, ...] rotation: bring noise (5D x) and the ref_latents (CONDList -> list)
# into a single 5D tensor (B, C, n*T, H, W) where slot 0 along T is the noise after rotation.
# Pass the noise latent and the reference latents to the transformer, which patchifies each
# component and concatenates them along the sequence dim. References may be any resolution.
if c_concat is not None:
raise ValueError("JoyImage does not support c_concat / noise_concat conditioning")
self._ensure_guidance_rescale_installed()
@ -2225,38 +2227,26 @@ class JoyImage(BaseModel):
if ref_latents is None or len(ref_latents) == 0:
raise ValueError("JoyImageEdit forward requires ref_latents; got none.")
# Build 6D (B, n, C, T, H, W) with refs first then noise, then rotate
# [last, first, ...] so the noise moves to the front, and reshape to 5D (B, C, n*T, H, W).
b, c, t_noise, h, w = xc.shape
ref_5d = []
if xc.ndim != 5:
raise ValueError("JoyImageEdit: noise latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(xc.shape)))
refs = []
for r in ref_latents:
if r.shape[-3:] != xc.shape[-3:]:
if r.ndim != 5:
raise ValueError(
"JoyImageEdit: reference latent spatial/temporal shape {} must match noise {}.".format(
tuple(r.shape), tuple(xc.shape)
)
"JoyImageEdit: each reference latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(r.shape))
)
ref_5d.append(r.to(device=device, dtype=dtype))
stacked = torch.stack([*ref_5d, xc], dim=1) # (B, n, C, T, H, W)
n = stacked.shape[1]
rotated = torch.cat([stacked[:, -1:], stacked[:, :-1]], dim=1) # noise -> front
flat = rotated.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t_noise, h, w)
refs.append(r.to(device=device, dtype=dtype))
if control is not None:
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states); it does
# not accept control/_options/extra_conds. Pass context positionally; the text-encoder
# output IS what's threaded into encoder_hidden_states.
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states,
# ref_latents); it does not accept control/_options/other extra_conds.
if extra_conds:
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
model_output = self.diffusion_model(flat, t_in, context)
# After the rotation noise sat at slot 0; pluck it back out from the n*T axis.
c_out = model_output.shape[1]
out_6d = model_output.reshape(b, c_out, n, t_noise, h, w)
noise_pred = out_6d[:, :, 0] # (B, C, T, H, W)
noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs)
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)

View File

@ -13,9 +13,10 @@ import torch.nn.functional as F
from comfy import sd1_clip
from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer
# Prompt templates for the text-only and image-conditioned modes. The
# image-conditioned template wraps the user text with a single
# `<|vision_start|><|image_pad|><|vision_end|>` block; one user turn per call.
# Prompt templates for the text-only and image-conditioned modes. The image-conditioned template
# wraps the user text with one `<|vision_start|><|image_pad|><|vision_end|>` block per reference
# image (no separator between blocks); `{vision}` is filled with the N concatenated blocks and
# `{prompt}` with the user text.
JOYIMAGE_TEMPLATE_TEXT = (
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
@ -25,9 +26,12 @@ JOYIMAGE_TEMPLATE_TEXT = (
JOYIMAGE_TEMPLATE_IMAGE = (
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
"<|im_start|>user\n{vision}{prompt}<|im_end|>\n<|im_start|>assistant\n"
)
# A single vision block; N copies are concatenated to condition on N reference images.
JOYIMAGE_VISION_BLOCK = "<|vision_start|><|image_pad|><|vision_end|>"
# Number of leading template tokens (system prompt + the user block's opening
# `<|im_start|>`) stripped from the encoded output by
# JoyImageTEModel.encode_token_weights, so the kept sequence begins at the
@ -165,12 +169,14 @@ class JoyImageTokenizer(Qwen3VLTokenizer):
"""JoyImageEdit tokenizer.
``tokenize_with_weights(text, images=[...])`` selects the image-conditioned
template when one or more image tensors are passed, otherwise the text-only
template. Each ``<|image_pad|>`` token in the formatted prompt is replaced
with an embedding marker so `SDClipModel.process_tokens` routes the image
through `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading
template tokens are stripped downstream by
`JoyImageTEModel.encode_token_weights`. No ``<think>`` block is appended.
template when one or more image tensors are passed, emitting one
``<|vision_start|><|image_pad|><|vision_end|>`` block per image (N blocks
for N reference images), otherwise the text-only template. Each
``<|image_pad|>`` token in the formatted prompt is replaced with an
embedding marker so `SDClipModel.process_tokens` routes each image through
`Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading template
tokens are stripped downstream by `JoyImageTEModel.encode_token_weights`.
No ``<think>`` block is appended.
"""
def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -188,7 +194,9 @@ class JoyImageTokenizer(Qwen3VLTokenizer):
elif llama_template is not None:
llama_text = llama_template.format(text)
elif len(images) > 0:
llama_text = self.llama_template_images.format(text)
# One vision block per reference image.
vision = JOYIMAGE_VISION_BLOCK * len(images)
llama_text = self.llama_template_images.format(vision=vision, prompt=text)
else:
llama_text = self.llama_template.format(text)

View File

@ -76,11 +76,80 @@ class TextEncodeJoyImageEdit(io.ComfyNode):
return io.NodeOutput(conditioning, resized_image)
class TextEncodeJoyImageEditPlus(io.ComfyNode):
"""JoyImageEdit multi-image (Plus) text-encode node.
Accepts 1-6 optional reference images. Each supplied image is
bucket-resized independently (same buckets/resize as the single-image
node), VAE-encoded, and appended in order to
``conditioning["reference_latents"]`` (image1 ref0, image2 ref1, ...).
All resized images are passed to the VL tower in one call; the tokenizer
emits one ``<|vision_start|><|image_pad|><|vision_end|>`` block per image.
"""
MAX_IMAGES = 6
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeJoyImageEditPlus",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Vae.Input("vae"),
io.Image.Input("image1", optional=True),
io.Image.Input("image2", optional=True),
io.Image.Input("image3", optional=True),
io.Image.Input("image4", optional=True),
io.Image.Input("image5", optional=True),
io.Image.Input("image6", optional=True),
],
outputs=[
io.Conditioning.Output(),
io.Image.Output(display_name="image"),
],
)
@classmethod
def execute(cls, clip, prompt, vae, image1=None, image2=None, image3=None,
image4=None, image5=None, image6=None) -> io.NodeOutput:
images = [image1, image2, image3, image4, image5, image6]
supplied = [img for img in images if img is not None]
if len(supplied) == 0:
raise ValueError(
"TextEncodeJoyImageEditPlus requires at least one reference image."
)
resized_images = []
ref_latents = []
for image in supplied:
samples = image.movedim(-1, 1)
src_h, src_w = samples.shape[2], samples.shape[3]
bucket_h, bucket_w = _find_best_bucket(src_h, src_w)
resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center")
resized_image = resized.movedim(1, -1)[:, :, :, :3]
resized_images.append(resized_image)
ref_latents.append(vae.encode(resized_image))
tokens = clip.tokenize(prompt, images=resized_images)
conditioning = clip.encode_from_tokens_scheduled(tokens)
conditioning = node_helpers.conditioning_set_values(
conditioning, {"reference_latents": ref_latents}, append=True,
)
# The last reference sets the target resolution; return it for VAEEncode and the
# matching negative encode.
return io.NodeOutput(conditioning, resized_images[-1])
class JoyImageExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeJoyImageEdit,
TextEncodeJoyImageEditPlus,
]