diff --git a/comfy/ldm/joyimage/model.py b/comfy/ldm/joyimage/model.py index e7c8cf9ce..a9640cb7c 100644 --- a/comfy/ldm/joyimage/model.py +++ b/comfy/ldm/joyimage/model.py @@ -292,8 +292,6 @@ class _PixArtAlphaTextProjection(nn.Module): class JoyImageTransformer3DModel(nn.Module): - # 6D->5D rotation and reshape happen in JoyImage.apply_model; this module is 5D-in, 5D-out. - def __init__( self, patch_size: list = [1, 2, 2], @@ -373,54 +371,54 @@ class JoyImageTransformer3DModel(nn.Module): device=device, ) - def get_rotary_pos_embed( + def _get_rotary_pos_embed_for_range( self, - vis_rope_size, - txt_rope_size: Optional[int] = None, + start: Tuple[int, int, int], + stop: Tuple[int, int, int], device=None, - ): - target_ndim = 3 - vis_rope_size = list(vis_rope_size) - if len(vis_rope_size) != target_ndim: - vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size - + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 3D RoPE for the patch grid range [start, stop) over (t, h, w). Token order after + # reshape(-1) is (t, h, w), matching the img_in Conv3d flatten. 
head_dim = self.hidden_size // self.num_attention_heads rope_dim_list = self.rope_dim_list if rope_dim_list is None: - rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)] + rope_dim_list = [head_dim // 3 for _ in range(3)] if sum(rope_dim_list) != head_dim: raise ValueError("sum(rope_dim_list) should equal head_dim") - grid = torch.stack( - torch.meshgrid( - *[torch.linspace(0, s, s + 1, dtype=torch.float32, device=device)[:s] for s in vis_rope_size], - indexing="ij", - ), - dim=0, - ) + grids = [torch.arange(start[i], stop[i], dtype=torch.float32, device=device) for i in range(3)] + mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0) - vis_cos, vis_sin = [], [] + cos_parts, sin_parts = [], [] for i, dim in enumerate(rope_dim_list): - pos = grid[i].reshape(-1) + pos = mesh[i].reshape(-1) freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim)) - freqs = torch.outer(pos.float(), freqs) - vis_cos.append(freqs.cos().repeat_interleave(2, dim=1)) - vis_sin.append(freqs.sin().repeat_interleave(2, dim=1)) - vis_freqs = (torch.cat(vis_cos, dim=1), torch.cat(vis_sin, dim=1)) + angles = torch.outer(pos, freqs) + cos_parts.append(angles.cos().repeat_interleave(2, dim=1)) + sin_parts.append(angles.sin().repeat_interleave(2, dim=1)) - if txt_rope_size is None: - return vis_freqs, None + return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1) - grid_txt = torch.arange(txt_rope_size, device=device) + grid.view(-1).max().item() + 1 - txt_cos, txt_sin = [], [] - for i, dim in enumerate(rope_dim_list): - freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim)) - freqs = torch.outer(grid_txt.float(), freqs) - txt_cos.append(freqs.cos().repeat_interleave(2, dim=1)) - txt_sin.append(freqs.sin().repeat_interleave(2, dim=1)) - txt_freqs = (torch.cat(txt_cos, dim=1), torch.cat(txt_sin, dim=1)) - - return vis_freqs, txt_freqs + def 
get_rotary_pos_embed_for_components( + self, + component_sizes, + device=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Per-component 3D RoPE. component_sizes is a list of (t, h, w) patch grid sizes in + # sequence order [target, ref0, ref1, ...]; h/w restart at 0 for each component while t + # continues from the running offset, giving every image its own temporal position band. + cos_parts, sin_parts = [], [] + t_offset = 0 + for (t, h, w) in component_sizes: + cos_emb, sin_emb = self._get_rotary_pos_embed_for_range( + start=(t_offset, 0, 0), + stop=(t_offset + t, h, w), + device=device, + ) + cos_parts.append(cos_emb) + sin_parts.append(sin_emb) + t_offset += t + return torch.cat(cos_parts, dim=0), torch.cat(sin_parts, dim=0) def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor: c = self.out_channels @@ -436,25 +434,57 @@ class JoyImageTransformer3DModel(nn.Module): hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, + ref_latents=None, ) -> torch.Tensor: - _, _, ot, oh, ow = hidden_states.shape - tt = ot // self.patch_size[0] - th = oh // self.patch_size[1] - tw = ow // self.patch_size[2] + # The target noise latent and each reference latent are independently patchified by img_in + # (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...]. + # RoPE is built per component so references may differ in resolution. Only the leading + # target segment (tt*th*tw tokens) is projected back out; reference tokens are dropped. + # A single reference is simply the len(ref_latents) == 1 case. 
+ if hidden_states.ndim != 5: + raise ValueError(f"JoyImage transformer expects 5D (B,C,T,H,W) hidden_states; got shape {tuple(hidden_states.shape)}") - img = self.img_in(hidden_states).flatten(2).transpose(1, 2) + _, _, ot, oh, ow = hidden_states.shape + pt, ph, pw = self.patch_size + if ot % pt != 0 or oh % ph != 0 or ow % pw != 0: + raise ValueError( + f"JoyImage: target latent spatial/temporal shape {(ot, oh, ow)} must be divisible by patch_size {tuple(self.patch_size)}" + ) + tt = ot // pt + th = oh // ph + tw = ow // pw + + components = [hidden_states] + if ref_latents is not None: + for r in ref_latents: + if r.ndim != 5: + raise ValueError(f"JoyImage: each reference latent must be 5D (B,C,T,H,W); got shape {tuple(r.shape)}") + components.append(r) + + component_sizes = [] + img_tokens = [] + for comp in components: + _, _, ct, ch, cw = comp.shape + if ct % pt != 0 or ch % ph != 0 or cw % pw != 0: + raise ValueError( + f"JoyImage: component shape {(ct, ch, cw)} must be divisible by patch_size {tuple(self.patch_size)}" + ) + component_sizes.append((ct // pt, ch // ph, cw // pw)) + tokens = self.img_in(comp).flatten(2).transpose(1, 2) # (B, n_i, D) + img_tokens.append(tokens) + + img = torch.cat(img_tokens, dim=1) _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states) if vec.shape[-1] > self.hidden_size: vec = vec.unflatten(1, (6, -1)) - txt_seq_len = txt.shape[1] - - vis_freqs, txt_freqs = self.get_rotary_pos_embed( - vis_rope_size=[tt, th, tw], - txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None, + vis_cos, vis_sin = self.get_rotary_pos_embed_for_components( + component_sizes, device=hidden_states.device, ) + vis_freqs = (vis_cos, vis_sin) + txt_freqs = None for block in self.double_blocks: img, txt = block( @@ -465,5 +495,7 @@ class JoyImageTransformer3DModel(nn.Module): ) img = self.proj_out(self.norm_out(img)) + target_tokens = tt * th * tw + img = img[:, :target_tokens, :] img = self.unpatchify(img, tt, th, tw) return img 
diff --git a/comfy/model_base.py b/comfy/model_base.py index 964fd9a8c..8b9f93ca2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -2131,8 +2131,9 @@ class QwenImage(BaseModel): return out class JoyImage(BaseModel): - # JoyImageEdit: 6D stacking + [last, first, ...] rotation, plus hard-wired guidance rescale, - # are deliberately handled HERE (not in the transformer) so the transformer stays 5D-in / 5D-out. + # The noise latent and every reference latent are concatenated as a token sequence inside the + # transformer. A single-reference edit is just the len(ref_latents) == 1 case. The built-in CFG + # guidance rescale is installed from here. def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel) self.memory_usage_factor_conds = ("ref_latents",) @@ -2177,8 +2178,9 @@ class JoyImage(BaseModel): if ref_latents is None or len(ref_latents) == 0: raise ValueError( "JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry " - "reference_latents. Connect the same image+vae into both TextEncodeJoyImageEdit nodes. " - "Empty negative prompts still need image+vae wired." + "reference_latents. Wire the same reference image(s) and vae into both the positive and " + "negative TextEncodeJoyImageEdit / TextEncodeJoyImageEditPlus nodes. Empty negative " + "prompts still need the image(s) and vae." ) latents = [] for lat in ref_latents: @@ -2194,8 +2196,8 @@ class JoyImage(BaseModel): return out def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): - # 6D stacking + [last, first, ...] rotation: bring noise (5D x) and the ref_latents (CONDList -> list) - # into a single 5D tensor (B, C, n*T, H, W) where slot 0 along T is the noise after rotation. 
+ # Pass the noise latent and the reference latents to the transformer, which patchifies each + # component and concatenates them along the sequence dim. References may be any resolution. if c_concat is not None: raise ValueError("JoyImage does not support c_concat / noise_concat conditioning") self._ensure_guidance_rescale_installed() @@ -2225,38 +2227,26 @@ class JoyImage(BaseModel): if ref_latents is None or len(ref_latents) == 0: raise ValueError("JoyImageEdit forward requires ref_latents; got none.") - # Build 6D (B, n, C, T, H, W) with refs first then noise, then rotate - # [last, first, ...] so the noise moves to the front, and reshape to 5D (B, C, n*T, H, W). - b, c, t_noise, h, w = xc.shape - ref_5d = [] + if xc.ndim != 5: + raise ValueError("JoyImageEdit: noise latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(xc.shape))) + + refs = [] for r in ref_latents: - if r.shape[-3:] != xc.shape[-3:]: + if r.ndim != 5: raise ValueError( - "JoyImageEdit: reference latent spatial/temporal shape {} must match noise {}.".format( - tuple(r.shape), tuple(xc.shape) - ) + "JoyImageEdit: each reference latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(r.shape)) ) - ref_5d.append(r.to(device=device, dtype=dtype)) - stacked = torch.stack([*ref_5d, xc], dim=1) # (B, n, C, T, H, W) - n = stacked.shape[1] - rotated = torch.cat([stacked[:, -1:], stacked[:, :-1]], dim=1) # noise -> front - flat = rotated.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t_noise, h, w) + refs.append(r.to(device=device, dtype=dtype)) if control is not None: raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.") - # The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states); it does - # not accept control/_options/extra_conds. Pass context positionally; the text-encoder - # output IS what's threaded into encoder_hidden_states. 
+ # The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states, + # ref_latents); it does not accept control/_options/other extra_conds. if extra_conds: raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys()))) - model_output = self.diffusion_model(flat, t_in, context) - - # After the rotation noise sat at slot 0; pluck it back out from the n*T axis. - c_out = model_output.shape[1] - out_6d = model_output.reshape(b, c_out, n, t_noise, h, w) - noise_pred = out_6d[:, :, 0] # (B, C, T, H, W) + noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs) return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x) diff --git a/comfy/text_encoders/joyimage.py b/comfy/text_encoders/joyimage.py index 959a2b164..04dadb949 100644 --- a/comfy/text_encoders/joyimage.py +++ b/comfy/text_encoders/joyimage.py @@ -13,9 +13,10 @@ import torch.nn.functional as F from comfy import sd1_clip from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer -# Prompt templates for the text-only and image-conditioned modes. The -# image-conditioned template wraps the user text with a single -# `<|vision_start|><|image_pad|><|vision_end|>` block; one user turn per call. +# Prompt templates for the text-only and image-conditioned modes. The image-conditioned template +# wraps the user text with one `<|vision_start|><|image_pad|><|vision_end|>` block per reference +# image (no separator between blocks); `{vision}` is filled with the N concatenated blocks and +# `{prompt}` with the user text. 
JOYIMAGE_TEMPLATE_TEXT = ( "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" @@ -25,9 +26,12 @@ JOYIMAGE_TEMPLATE_TEXT = ( JOYIMAGE_TEMPLATE_IMAGE = ( "<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, " "quantity, text, spatial relationships of the objects and background:<|im_end|>\n" - "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + "<|im_start|>user\n{vision}{prompt}<|im_end|>\n<|im_start|>assistant\n" ) +# A single vision block; N copies are concatenated to condition on N reference images. +JOYIMAGE_VISION_BLOCK = "<|vision_start|><|image_pad|><|vision_end|>" + # Number of leading template tokens (system prompt + the user block's opening # `<|im_start|>`) stripped from the encoded output by # JoyImageTEModel.encode_token_weights, so the kept sequence begins at the @@ -165,12 +169,14 @@ class JoyImageTokenizer(Qwen3VLTokenizer): """JoyImageEdit tokenizer. ``tokenize_with_weights(text, images=[...])`` selects the image-conditioned - template when one or more image tensors are passed, otherwise the text-only - template. Each ``<|image_pad|>`` token in the formatted prompt is replaced - with an embedding marker so `SDClipModel.process_tokens` routes the image - through `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading - template tokens are stripped downstream by - `JoyImageTEModel.encode_token_weights`. No ``<think>`` block is appended. + template when one or more image tensors are passed, emitting one + ``<|vision_start|><|image_pad|><|vision_end|>`` block per image (N blocks + for N reference images), otherwise the text-only template. 
Each + ``<|image_pad|>`` token in the formatted prompt is replaced with an + embedding marker so `SDClipModel.process_tokens` routes each image through + `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading template + tokens are stripped downstream by `JoyImageTEModel.encode_token_weights`. + No ``<think>`` block is appended. """ def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -188,7 +194,9 @@ class JoyImageTokenizer(Qwen3VLTokenizer): elif llama_template is not None: llama_text = llama_template.format(text) elif len(images) > 0: - llama_text = self.llama_template_images.format(text) + # One vision block per reference image. + vision = JOYIMAGE_VISION_BLOCK * len(images) + llama_text = self.llama_template_images.format(vision=vision, prompt=text) else: llama_text = self.llama_template.format(text) diff --git a/comfy_extras/nodes_joyimage.py b/comfy_extras/nodes_joyimage.py index a18eddd09..72c7f3b7f 100644 --- a/comfy_extras/nodes_joyimage.py +++ b/comfy_extras/nodes_joyimage.py @@ -76,11 +76,80 @@ class TextEncodeJoyImageEdit(io.ComfyNode): return io.NodeOutput(conditioning, resized_image) +class TextEncodeJoyImageEditPlus(io.ComfyNode): + """JoyImageEdit multi-image (Plus) text-encode node. + + Accepts 1-6 optional reference images. Each supplied image is + bucket-resized independently (same buckets/resize as the single-image + node), VAE-encoded, and appended in order to + ``conditioning["reference_latents"]`` (image1 → ref0, image2 → ref1, ...). + All resized images are passed to the VL tower in one call; the tokenizer + emits one ``<|vision_start|><|image_pad|><|vision_end|>`` block per image. 
+ """ + + MAX_IMAGES = 6 + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeJoyImageEditPlus", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Vae.Input("vae"), + io.Image.Input("image1", optional=True), + io.Image.Input("image2", optional=True), + io.Image.Input("image3", optional=True), + io.Image.Input("image4", optional=True), + io.Image.Input("image5", optional=True), + io.Image.Input("image6", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + io.Image.Output(display_name="image"), + ], + ) + + @classmethod + def execute(cls, clip, prompt, vae, image1=None, image2=None, image3=None, + image4=None, image5=None, image6=None) -> io.NodeOutput: + images = [image1, image2, image3, image4, image5, image6] + supplied = [img for img in images if img is not None] + if len(supplied) == 0: + raise ValueError( + "TextEncodeJoyImageEditPlus requires at least one reference image." + ) + + resized_images = [] + ref_latents = [] + for image in supplied: + samples = image.movedim(-1, 1) + src_h, src_w = samples.shape[2], samples.shape[3] + bucket_h, bucket_w = _find_best_bucket(src_h, src_w) + + resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center") + resized_image = resized.movedim(1, -1)[:, :, :, :3] + resized_images.append(resized_image) + ref_latents.append(vae.encode(resized_image)) + + tokens = clip.tokenize(prompt, images=resized_images) + conditioning = clip.encode_from_tokens_scheduled(tokens) + conditioning = node_helpers.conditioning_set_values( + conditioning, {"reference_latents": ref_latents}, append=True, + ) + + # The last reference sets the target resolution; return it for VAEEncode and the + # matching negative encode. 
+ return io.NodeOutput(conditioning, resized_images[-1]) + + class JoyImageExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeJoyImageEdit, + TextEncodeJoyImageEditPlus, ]