mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Add JoyImageEditPlus multi-image edit support (unify onto Plus-style forward)
JoyImageEditPlus is the multi-image (1-6 reference images) variant of JoyImageEdit, trained from the same base. Its diffusers transformer shares byte-identical weight structure with the single-image variant (894 keys, zero rename) but injects references differently: instead of the single-image slot-stack (stack refs + noise into a 6D tensor and rotate on the frame dim, which forces all items to share resolution), each reference is independently patchified and concatenated on the sequence dim with per-image temporal-offset 3D RoPE, allowing references at different resolutions. Since the single-image port is not yet upstream, this unifies both variants onto the Plus-style forward rather than keeping two paths; single-image is now the ref=1 special case. Verified numerically: at ref=1 with equal resolution the new path's RoPE is bit-identical to the old slot-stack layout, and the transformer output matches the diffusers Plus reference (fp32, incl. the different-resolution case). ComfyUI runs cond/uncond in one forward with a shared reference configuration, so the diffusers Plus batched RoPE, padding attention_mask, and dedicated attention processor are unnecessary here: the unified forward reuses the existing unbatched _apply_rotary_emb and JoyImageAttention. Confirmed equivalent to the diffusers batched+mask path for a single sample. - comfy/ldm/joyimage/model.py: forward takes ref_latents and builds components=[target, ref0, ...]; per-component patchify + temporal-offset RoPE; output keeps only the target segment. Old single-grid RoPE removed. - comfy/model_base.py: JoyImage drops the slot-stack / frame-rotation / shape-equality path in _apply_model, passing ref_latents straight to the transformer. Guidance-rescale and the reference_latents requirement are kept. - comfy/text_encoders/joyimage.py: the image template emits one vision block per reference (N = image count); N=1 is byte-for-byte the old template. - comfy_extras/nodes_joyimage.py: add TextEncodeJoyImageEditPlus with optional image1..image6 inputs, each bucket-resized and VAE-encoded into the reference_latents list. Detection, supported_models, and sd.py need no changes: the identical weight structure routes both variants through image_model="joyimage".
This commit is contained in:
parent
e96bd48e2d
commit
e29384be0d
@ -292,8 +292,6 @@ class _PixArtAlphaTextProjection(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class JoyImageTransformer3DModel(nn.Module):
|
class JoyImageTransformer3DModel(nn.Module):
|
||||||
# 6D->5D rotation and reshape happen in JoyImage.apply_model; this module is 5D-in, 5D-out.
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
patch_size: list = [1, 2, 2],
|
patch_size: list = [1, 2, 2],
|
||||||
@ -373,54 +371,54 @@ class JoyImageTransformer3DModel(nn.Module):
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_rotary_pos_embed(
|
def _get_rotary_pos_embed_for_range(
|
||||||
self,
|
self,
|
||||||
vis_rope_size,
|
start: Tuple[int, int, int],
|
||||||
txt_rope_size: Optional[int] = None,
|
stop: Tuple[int, int, int],
|
||||||
device=None,
|
device=None,
|
||||||
):
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
target_ndim = 3
|
# 3D RoPE for the patch grid range [start, stop) over (t, h, w). Token order after
|
||||||
vis_rope_size = list(vis_rope_size)
|
# reshape(-1) is (t, h, w), matching the img_in Conv3d flatten.
|
||||||
if len(vis_rope_size) != target_ndim:
|
|
||||||
vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
|
|
||||||
|
|
||||||
head_dim = self.hidden_size // self.num_attention_heads
|
head_dim = self.hidden_size // self.num_attention_heads
|
||||||
rope_dim_list = self.rope_dim_list
|
rope_dim_list = self.rope_dim_list
|
||||||
if rope_dim_list is None:
|
if rope_dim_list is None:
|
||||||
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
rope_dim_list = [head_dim // 3 for _ in range(3)]
|
||||||
if sum(rope_dim_list) != head_dim:
|
if sum(rope_dim_list) != head_dim:
|
||||||
raise ValueError("sum(rope_dim_list) should equal head_dim")
|
raise ValueError("sum(rope_dim_list) should equal head_dim")
|
||||||
|
|
||||||
grid = torch.stack(
|
grids = [torch.arange(start[i], stop[i], dtype=torch.float32, device=device) for i in range(3)]
|
||||||
torch.meshgrid(
|
mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0)
|
||||||
*[torch.linspace(0, s, s + 1, dtype=torch.float32, device=device)[:s] for s in vis_rope_size],
|
|
||||||
indexing="ij",
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
vis_cos, vis_sin = [], []
|
cos_parts, sin_parts = [], []
|
||||||
for i, dim in enumerate(rope_dim_list):
|
for i, dim in enumerate(rope_dim_list):
|
||||||
pos = grid[i].reshape(-1)
|
pos = mesh[i].reshape(-1)
|
||||||
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
|
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
|
||||||
freqs = torch.outer(pos.float(), freqs)
|
angles = torch.outer(pos, freqs)
|
||||||
vis_cos.append(freqs.cos().repeat_interleave(2, dim=1))
|
cos_parts.append(angles.cos().repeat_interleave(2, dim=1))
|
||||||
vis_sin.append(freqs.sin().repeat_interleave(2, dim=1))
|
sin_parts.append(angles.sin().repeat_interleave(2, dim=1))
|
||||||
vis_freqs = (torch.cat(vis_cos, dim=1), torch.cat(vis_sin, dim=1))
|
|
||||||
|
|
||||||
if txt_rope_size is None:
|
return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)
|
||||||
return vis_freqs, None
|
|
||||||
|
|
||||||
grid_txt = torch.arange(txt_rope_size, device=device) + grid.view(-1).max().item() + 1
|
def get_rotary_pos_embed_for_components(
|
||||||
txt_cos, txt_sin = [], []
|
self,
|
||||||
for i, dim in enumerate(rope_dim_list):
|
component_sizes,
|
||||||
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
|
device=None,
|
||||||
freqs = torch.outer(grid_txt.float(), freqs)
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
txt_cos.append(freqs.cos().repeat_interleave(2, dim=1))
|
# Per-component 3D RoPE. component_sizes is a list of (t, h, w) patch grid sizes in
|
||||||
txt_sin.append(freqs.sin().repeat_interleave(2, dim=1))
|
# sequence order [target, ref0, ref1, ...]; h/w restart at 0 for each component while t
|
||||||
txt_freqs = (torch.cat(txt_cos, dim=1), torch.cat(txt_sin, dim=1))
|
# continues from the running offset, giving every image its own temporal position band.
|
||||||
|
cos_parts, sin_parts = [], []
|
||||||
return vis_freqs, txt_freqs
|
t_offset = 0
|
||||||
|
for (t, h, w) in component_sizes:
|
||||||
|
cos_emb, sin_emb = self._get_rotary_pos_embed_for_range(
|
||||||
|
start=(t_offset, 0, 0),
|
||||||
|
stop=(t_offset + t, h, w),
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
cos_parts.append(cos_emb)
|
||||||
|
sin_parts.append(sin_emb)
|
||||||
|
t_offset += t
|
||||||
|
return torch.cat(cos_parts, dim=0), torch.cat(sin_parts, dim=0)
|
||||||
|
|
||||||
def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
|
def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
|
||||||
c = self.out_channels
|
c = self.out_channels
|
||||||
@ -436,25 +434,57 @@ class JoyImageTransformer3DModel(nn.Module):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
timestep: torch.Tensor,
|
timestep: torch.Tensor,
|
||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
ref_latents=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
_, _, ot, oh, ow = hidden_states.shape
|
# The target noise latent and each reference latent are independently patchified by img_in
|
||||||
tt = ot // self.patch_size[0]
|
# (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...].
|
||||||
th = oh // self.patch_size[1]
|
# RoPE is built per component so references may differ in resolution. Only the leading
|
||||||
tw = ow // self.patch_size[2]
|
# target segment (tt*th*tw tokens) is projected back out; reference tokens are dropped.
|
||||||
|
# A single reference is simply the len(ref_latents) == 1 case.
|
||||||
|
if hidden_states.ndim != 5:
|
||||||
|
raise ValueError(f"JoyImage transformer expects 5D (B,C,T,H,W) hidden_states; got shape {tuple(hidden_states.shape)}")
|
||||||
|
|
||||||
img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
|
_, _, ot, oh, ow = hidden_states.shape
|
||||||
|
pt, ph, pw = self.patch_size
|
||||||
|
if ot % pt != 0 or oh % ph != 0 or ow % pw != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"JoyImage: target latent spatial/temporal shape {(ot, oh, ow)} must be divisible by patch_size {tuple(self.patch_size)}"
|
||||||
|
)
|
||||||
|
tt = ot // pt
|
||||||
|
th = oh // ph
|
||||||
|
tw = ow // pw
|
||||||
|
|
||||||
|
components = [hidden_states]
|
||||||
|
if ref_latents is not None:
|
||||||
|
for r in ref_latents:
|
||||||
|
if r.ndim != 5:
|
||||||
|
raise ValueError(f"JoyImage: each reference latent must be 5D (B,C,T,H,W); got shape {tuple(r.shape)}")
|
||||||
|
components.append(r)
|
||||||
|
|
||||||
|
component_sizes = []
|
||||||
|
img_tokens = []
|
||||||
|
for comp in components:
|
||||||
|
_, _, ct, ch, cw = comp.shape
|
||||||
|
if ct % pt != 0 or ch % ph != 0 or cw % pw != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"JoyImage: component shape {(ct, ch, cw)} must be divisible by patch_size {tuple(self.patch_size)}"
|
||||||
|
)
|
||||||
|
component_sizes.append((ct // pt, ch // ph, cw // pw))
|
||||||
|
tokens = self.img_in(comp).flatten(2).transpose(1, 2) # (B, n_i, D)
|
||||||
|
img_tokens.append(tokens)
|
||||||
|
|
||||||
|
img = torch.cat(img_tokens, dim=1)
|
||||||
|
|
||||||
_, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
|
_, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
|
||||||
if vec.shape[-1] > self.hidden_size:
|
if vec.shape[-1] > self.hidden_size:
|
||||||
vec = vec.unflatten(1, (6, -1))
|
vec = vec.unflatten(1, (6, -1))
|
||||||
|
|
||||||
txt_seq_len = txt.shape[1]
|
vis_cos, vis_sin = self.get_rotary_pos_embed_for_components(
|
||||||
|
component_sizes,
|
||||||
vis_freqs, txt_freqs = self.get_rotary_pos_embed(
|
|
||||||
vis_rope_size=[tt, th, tw],
|
|
||||||
txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,
|
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
)
|
)
|
||||||
|
vis_freqs = (vis_cos, vis_sin)
|
||||||
|
txt_freqs = None
|
||||||
|
|
||||||
for block in self.double_blocks:
|
for block in self.double_blocks:
|
||||||
img, txt = block(
|
img, txt = block(
|
||||||
@ -465,5 +495,7 @@ class JoyImageTransformer3DModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
img = self.proj_out(self.norm_out(img))
|
img = self.proj_out(self.norm_out(img))
|
||||||
|
target_tokens = tt * th * tw
|
||||||
|
img = img[:, :target_tokens, :]
|
||||||
img = self.unpatchify(img, tt, th, tw)
|
img = self.unpatchify(img, tt, th, tw)
|
||||||
return img
|
return img
|
||||||
|
|||||||
@ -2131,8 +2131,9 @@ class QwenImage(BaseModel):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
class JoyImage(BaseModel):
|
class JoyImage(BaseModel):
|
||||||
# JoyImageEdit: 6D stacking + [last, first, ...] rotation, plus hard-wired guidance rescale,
|
# The noise latent and every reference latent are concatenated as a token sequence inside the
|
||||||
# are deliberately handled HERE (not in the transformer) so the transformer stays 5D-in / 5D-out.
|
# transformer. A single-reference edit is just the len(ref_latents) == 1 case. The built-in CFG
|
||||||
|
# guidance rescale is installed from here.
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel)
|
||||||
self.memory_usage_factor_conds = ("ref_latents",)
|
self.memory_usage_factor_conds = ("ref_latents",)
|
||||||
@ -2177,8 +2178,9 @@ class JoyImage(BaseModel):
|
|||||||
if ref_latents is None or len(ref_latents) == 0:
|
if ref_latents is None or len(ref_latents) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry "
|
"JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry "
|
||||||
"reference_latents. Connect the same image+vae into both TextEncodeJoyImageEdit nodes. "
|
"reference_latents. Wire the same reference image(s) and vae into both the positive and "
|
||||||
"Empty negative prompts still need image+vae wired."
|
"negative TextEncodeJoyImageEdit / TextEncodeJoyImageEditPlus nodes. Empty negative "
|
||||||
|
"prompts still need the image(s) and vae."
|
||||||
)
|
)
|
||||||
latents = []
|
latents = []
|
||||||
for lat in ref_latents:
|
for lat in ref_latents:
|
||||||
@ -2194,8 +2196,8 @@ class JoyImage(BaseModel):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
# 6D stacking + [last, first, ...] rotation: bring noise (5D x) and the ref_latents (CONDList -> list)
|
# Pass the noise latent and the reference latents to the transformer, which patchifies each
|
||||||
# into a single 5D tensor (B, C, n*T, H, W) where slot 0 along T is the noise after rotation.
|
# component and concatenates them along the sequence dim. References may be any resolution.
|
||||||
if c_concat is not None:
|
if c_concat is not None:
|
||||||
raise ValueError("JoyImage does not support c_concat / noise_concat conditioning")
|
raise ValueError("JoyImage does not support c_concat / noise_concat conditioning")
|
||||||
self._ensure_guidance_rescale_installed()
|
self._ensure_guidance_rescale_installed()
|
||||||
@ -2225,38 +2227,26 @@ class JoyImage(BaseModel):
|
|||||||
if ref_latents is None or len(ref_latents) == 0:
|
if ref_latents is None or len(ref_latents) == 0:
|
||||||
raise ValueError("JoyImageEdit forward requires ref_latents; got none.")
|
raise ValueError("JoyImageEdit forward requires ref_latents; got none.")
|
||||||
|
|
||||||
# Build 6D (B, n, C, T, H, W) with refs first then noise, then rotate
|
if xc.ndim != 5:
|
||||||
# [last, first, ...] so the noise moves to the front, and reshape to 5D (B, C, n*T, H, W).
|
raise ValueError("JoyImageEdit: noise latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(xc.shape)))
|
||||||
b, c, t_noise, h, w = xc.shape
|
|
||||||
ref_5d = []
|
refs = []
|
||||||
for r in ref_latents:
|
for r in ref_latents:
|
||||||
if r.shape[-3:] != xc.shape[-3:]:
|
if r.ndim != 5:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"JoyImageEdit: reference latent spatial/temporal shape {} must match noise {}.".format(
|
"JoyImageEdit: each reference latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(r.shape))
|
||||||
tuple(r.shape), tuple(xc.shape)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
ref_5d.append(r.to(device=device, dtype=dtype))
|
refs.append(r.to(device=device, dtype=dtype))
|
||||||
stacked = torch.stack([*ref_5d, xc], dim=1) # (B, n, C, T, H, W)
|
|
||||||
n = stacked.shape[1]
|
|
||||||
rotated = torch.cat([stacked[:, -1:], stacked[:, :-1]], dim=1) # noise -> front
|
|
||||||
flat = rotated.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t_noise, h, w)
|
|
||||||
|
|
||||||
if control is not None:
|
if control is not None:
|
||||||
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
|
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
|
||||||
|
|
||||||
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states); it does
|
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states,
|
||||||
# not accept control/_options/extra_conds. Pass context positionally; the text-encoder
|
# ref_latents); it does not accept control/_options/other extra_conds.
|
||||||
# output IS what's threaded into encoder_hidden_states.
|
|
||||||
if extra_conds:
|
if extra_conds:
|
||||||
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
|
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
|
||||||
|
|
||||||
model_output = self.diffusion_model(flat, t_in, context)
|
noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs)
|
||||||
|
|
||||||
# After the rotation noise sat at slot 0; pluck it back out from the n*T axis.
|
|
||||||
c_out = model_output.shape[1]
|
|
||||||
out_6d = model_output.reshape(b, c_out, n, t_noise, h, w)
|
|
||||||
noise_pred = out_6d[:, :, 0] # (B, C, T, H, W)
|
|
||||||
|
|
||||||
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
|
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
|
||||||
|
|
||||||
|
|||||||
@ -13,9 +13,10 @@ import torch.nn.functional as F
|
|||||||
from comfy import sd1_clip
|
from comfy import sd1_clip
|
||||||
from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer
|
from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer
|
||||||
|
|
||||||
# Prompt templates for the text-only and image-conditioned modes. The
|
# Prompt templates for the text-only and image-conditioned modes. The image-conditioned template
|
||||||
# image-conditioned template wraps the user text with a single
|
# wraps the user text with one `<|vision_start|><|image_pad|><|vision_end|>` block per reference
|
||||||
# `<|vision_start|><|image_pad|><|vision_end|>` block; one user turn per call.
|
# image (no separator between blocks); `{vision}` is filled with the N concatenated blocks and
|
||||||
|
# `{prompt}` with the user text.
|
||||||
JOYIMAGE_TEMPLATE_TEXT = (
|
JOYIMAGE_TEMPLATE_TEXT = (
|
||||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
||||||
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
||||||
@ -25,9 +26,12 @@ JOYIMAGE_TEMPLATE_TEXT = (
|
|||||||
JOYIMAGE_TEMPLATE_IMAGE = (
|
JOYIMAGE_TEMPLATE_IMAGE = (
|
||||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
||||||
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
||||||
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
"<|im_start|>user\n{vision}{prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A single vision block; N copies are concatenated to condition on N reference images.
|
||||||
|
JOYIMAGE_VISION_BLOCK = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||||
|
|
||||||
# Number of leading template tokens (system prompt + the user block's opening
|
# Number of leading template tokens (system prompt + the user block's opening
|
||||||
# `<|im_start|>`) stripped from the encoded output by
|
# `<|im_start|>`) stripped from the encoded output by
|
||||||
# JoyImageTEModel.encode_token_weights, so the kept sequence begins at the
|
# JoyImageTEModel.encode_token_weights, so the kept sequence begins at the
|
||||||
@ -165,12 +169,14 @@ class JoyImageTokenizer(Qwen3VLTokenizer):
|
|||||||
"""JoyImageEdit tokenizer.
|
"""JoyImageEdit tokenizer.
|
||||||
|
|
||||||
``tokenize_with_weights(text, images=[...])`` selects the image-conditioned
|
``tokenize_with_weights(text, images=[...])`` selects the image-conditioned
|
||||||
template when one or more image tensors are passed, otherwise the text-only
|
template when one or more image tensors are passed, emitting one
|
||||||
template. Each ``<|image_pad|>`` token in the formatted prompt is replaced
|
``<|vision_start|><|image_pad|><|vision_end|>`` block per image (N blocks
|
||||||
with an embedding marker so `SDClipModel.process_tokens` routes the image
|
for N reference images), otherwise the text-only template. Each
|
||||||
through `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading
|
``<|image_pad|>`` token in the formatted prompt is replaced with an
|
||||||
template tokens are stripped downstream by
|
embedding marker so `SDClipModel.process_tokens` routes each image through
|
||||||
`JoyImageTEModel.encode_token_weights`. No ``<think>`` block is appended.
|
`Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading template
|
||||||
|
tokens are stripped downstream by `JoyImageTEModel.encode_token_weights`.
|
||||||
|
No ``<think>`` block is appended.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -188,7 +194,9 @@ class JoyImageTokenizer(Qwen3VLTokenizer):
|
|||||||
elif llama_template is not None:
|
elif llama_template is not None:
|
||||||
llama_text = llama_template.format(text)
|
llama_text = llama_template.format(text)
|
||||||
elif len(images) > 0:
|
elif len(images) > 0:
|
||||||
llama_text = self.llama_template_images.format(text)
|
# One vision block per reference image.
|
||||||
|
vision = JOYIMAGE_VISION_BLOCK * len(images)
|
||||||
|
llama_text = self.llama_template_images.format(vision=vision, prompt=text)
|
||||||
else:
|
else:
|
||||||
llama_text = self.llama_template.format(text)
|
llama_text = self.llama_template.format(text)
|
||||||
|
|
||||||
|
|||||||
@ -76,11 +76,80 @@ class TextEncodeJoyImageEdit(io.ComfyNode):
|
|||||||
return io.NodeOutput(conditioning, resized_image)
|
return io.NodeOutput(conditioning, resized_image)
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncodeJoyImageEditPlus(io.ComfyNode):
|
||||||
|
"""JoyImageEdit multi-image (Plus) text-encode node.
|
||||||
|
|
||||||
|
Accepts 1-6 optional reference images. Each supplied image is
|
||||||
|
bucket-resized independently (same buckets/resize as the single-image
|
||||||
|
node), VAE-encoded, and appended in order to
|
||||||
|
``conditioning["reference_latents"]`` (image1 → ref0, image2 → ref1, ...).
|
||||||
|
All resized images are passed to the VL tower in one call; the tokenizer
|
||||||
|
emits one ``<|vision_start|><|image_pad|><|vision_end|>`` block per image.
|
||||||
|
"""
|
||||||
|
|
||||||
|
MAX_IMAGES = 6
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="TextEncodeJoyImageEditPlus",
|
||||||
|
category="advanced/conditioning",
|
||||||
|
inputs=[
|
||||||
|
io.Clip.Input("clip"),
|
||||||
|
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Image.Input("image1", optional=True),
|
||||||
|
io.Image.Input("image2", optional=True),
|
||||||
|
io.Image.Input("image3", optional=True),
|
||||||
|
io.Image.Input("image4", optional=True),
|
||||||
|
io.Image.Input("image5", optional=True),
|
||||||
|
io.Image.Input("image6", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
io.Image.Output(display_name="image"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, clip, prompt, vae, image1=None, image2=None, image3=None,
|
||||||
|
image4=None, image5=None, image6=None) -> io.NodeOutput:
|
||||||
|
images = [image1, image2, image3, image4, image5, image6]
|
||||||
|
supplied = [img for img in images if img is not None]
|
||||||
|
if len(supplied) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"TextEncodeJoyImageEditPlus requires at least one reference image."
|
||||||
|
)
|
||||||
|
|
||||||
|
resized_images = []
|
||||||
|
ref_latents = []
|
||||||
|
for image in supplied:
|
||||||
|
samples = image.movedim(-1, 1)
|
||||||
|
src_h, src_w = samples.shape[2], samples.shape[3]
|
||||||
|
bucket_h, bucket_w = _find_best_bucket(src_h, src_w)
|
||||||
|
|
||||||
|
resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center")
|
||||||
|
resized_image = resized.movedim(1, -1)[:, :, :, :3]
|
||||||
|
resized_images.append(resized_image)
|
||||||
|
ref_latents.append(vae.encode(resized_image))
|
||||||
|
|
||||||
|
tokens = clip.tokenize(prompt, images=resized_images)
|
||||||
|
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||||
|
conditioning = node_helpers.conditioning_set_values(
|
||||||
|
conditioning, {"reference_latents": ref_latents}, append=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The last reference sets the target resolution; return it for VAEEncode and the
|
||||||
|
# matching negative encode.
|
||||||
|
return io.NodeOutput(conditioning, resized_images[-1])
|
||||||
|
|
||||||
|
|
||||||
class JoyImageExtension(ComfyExtension):
|
class JoyImageExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
TextEncodeJoyImageEdit,
|
TextEncodeJoyImageEdit,
|
||||||
|
TextEncodeJoyImageEditPlus,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user