mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Add JoyImageEditPlus multi-image edit support (unify onto Plus-style forward)
JoyImageEditPlus is the multi-image (1-6 reference images) variant of JoyImageEdit, trained from the same base. Its diffusers transformer shares a byte-identical weight structure with the single-image variant (894 keys, zero renames) but injects references differently: instead of the single-image slot-stack (stack refs + noise into a 6D tensor and rotate on the frame dim, which forces all items to share one resolution), each reference is independently patchified and concatenated on the sequence dim with per-image temporal-offset 3D RoPE, allowing references at different resolutions.

Since the single-image port is not yet upstream, this unifies both variants onto the Plus-style forward rather than keeping two paths; single-image is now the ref=1 special case. Verified numerically: at ref=1 with equal resolution the new path's RoPE is bit-identical to the old slot-stack layout, and the transformer output matches the diffusers Plus reference (fp32, incl. the different-resolution case).

ComfyUI runs cond/uncond in one forward with a shared reference configuration, so the diffusers Plus batched RoPE, padding attention_mask, and dedicated attention processor are unnecessary here: the unified forward reuses the existing unbatched _apply_rotary_emb and JoyImageAttention. Confirmed equivalent to the diffusers batched+mask path for a single sample.

- comfy/ldm/joyimage/model.py: forward takes ref_latents and builds components=[target, ref0, ...]; per-component patchify + temporal-offset RoPE; output keeps only the target segment. Old single-grid RoPE removed.
- comfy/model_base.py: JoyImage drops the slot-stack / frame-rotation / shape-equality path in _apply_model, passing ref_latents straight to the transformer. Guidance-rescale and the reference_latents requirement are kept.
- comfy/text_encoders/joyimage.py: the image template emits one vision block per reference (N = image count); N=1 is byte-for-byte the old template.
- comfy_extras/nodes_joyimage.py: add TextEncodeJoyImageEditPlus with optional image1..image6 inputs, each bucket-resized and VAE-encoded into the reference_latents list.

Detection, supported_models, and sd.py need no changes: the identical weight structure routes both variants through image_model="joyimage".
This commit is contained in:
parent
e96bd48e2d
commit
e29384be0d
@ -292,8 +292,6 @@ class _PixArtAlphaTextProjection(nn.Module):
|
||||
|
||||
|
||||
class JoyImageTransformer3DModel(nn.Module):
|
||||
# 6D->5D rotation and reshape happen in JoyImage.apply_model; this module is 5D-in, 5D-out.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: list = [1, 2, 2],
|
||||
@ -373,54 +371,54 @@ class JoyImageTransformer3DModel(nn.Module):
|
||||
device=device,
|
||||
)
|
||||
|
||||
def get_rotary_pos_embed(
|
||||
def _get_rotary_pos_embed_for_range(
|
||||
self,
|
||||
vis_rope_size,
|
||||
txt_rope_size: Optional[int] = None,
|
||||
start: Tuple[int, int, int],
|
||||
stop: Tuple[int, int, int],
|
||||
device=None,
|
||||
):
|
||||
target_ndim = 3
|
||||
vis_rope_size = list(vis_rope_size)
|
||||
if len(vis_rope_size) != target_ndim:
|
||||
vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size
|
||||
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 3D RoPE for the patch grid range [start, stop) over (t, h, w). Token order after
|
||||
# reshape(-1) is (t, h, w), matching the img_in Conv3d flatten.
|
||||
head_dim = self.hidden_size // self.num_attention_heads
|
||||
rope_dim_list = self.rope_dim_list
|
||||
if rope_dim_list is None:
|
||||
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
||||
rope_dim_list = [head_dim // 3 for _ in range(3)]
|
||||
if sum(rope_dim_list) != head_dim:
|
||||
raise ValueError("sum(rope_dim_list) should equal head_dim")
|
||||
|
||||
grid = torch.stack(
|
||||
torch.meshgrid(
|
||||
*[torch.linspace(0, s, s + 1, dtype=torch.float32, device=device)[:s] for s in vis_rope_size],
|
||||
indexing="ij",
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
grids = [torch.arange(start[i], stop[i], dtype=torch.float32, device=device) for i in range(3)]
|
||||
mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0)
|
||||
|
||||
vis_cos, vis_sin = [], []
|
||||
cos_parts, sin_parts = [], []
|
||||
for i, dim in enumerate(rope_dim_list):
|
||||
pos = grid[i].reshape(-1)
|
||||
pos = mesh[i].reshape(-1)
|
||||
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
|
||||
freqs = torch.outer(pos.float(), freqs)
|
||||
vis_cos.append(freqs.cos().repeat_interleave(2, dim=1))
|
||||
vis_sin.append(freqs.sin().repeat_interleave(2, dim=1))
|
||||
vis_freqs = (torch.cat(vis_cos, dim=1), torch.cat(vis_sin, dim=1))
|
||||
angles = torch.outer(pos, freqs)
|
||||
cos_parts.append(angles.cos().repeat_interleave(2, dim=1))
|
||||
sin_parts.append(angles.sin().repeat_interleave(2, dim=1))
|
||||
|
||||
if txt_rope_size is None:
|
||||
return vis_freqs, None
|
||||
return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)
|
||||
|
||||
grid_txt = torch.arange(txt_rope_size, device=device) + grid.view(-1).max().item() + 1
|
||||
txt_cos, txt_sin = [], []
|
||||
for i, dim in enumerate(rope_dim_list):
|
||||
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
|
||||
freqs = torch.outer(grid_txt.float(), freqs)
|
||||
txt_cos.append(freqs.cos().repeat_interleave(2, dim=1))
|
||||
txt_sin.append(freqs.sin().repeat_interleave(2, dim=1))
|
||||
txt_freqs = (torch.cat(txt_cos, dim=1), torch.cat(txt_sin, dim=1))
|
||||
|
||||
return vis_freqs, txt_freqs
|
||||
def get_rotary_pos_embed_for_components(
    self,
    component_sizes,
    device=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the concatenated per-component 3D RoPE tables.

    Args:
        component_sizes: iterable of ``(t, h, w)`` patch-grid sizes in
            sequence order ``[target, ref0, ref1, ...]``.
        device: forwarded unchanged to ``_get_rotary_pos_embed_for_range``.

    Returns:
        ``(cos, sin)``: the tables returned for each component, concatenated
        along dim 0 in component order.

    Raises:
        ValueError: if ``component_sizes`` is empty (there would be nothing
            to concatenate).
    """
    component_sizes = list(component_sizes)
    if not component_sizes:
        # torch.cat on an empty list fails with an opaque RuntimeError; fail early
        # with a message that names the actual problem.
        raise ValueError("component_sizes must contain at least one (t, h, w) entry")

    cos_parts, sin_parts = [], []
    # h/w restart at 0 for each component while t continues from the running offset,
    # giving every image its own temporal position band.
    t_offset = 0
    for t, h, w in component_sizes:
        cos_emb, sin_emb = self._get_rotary_pos_embed_for_range(
            start=(t_offset, 0, 0),
            stop=(t_offset + t, h, w),
            device=device,
        )
        cos_parts.append(cos_emb)
        sin_parts.append(sin_emb)
        t_offset += t
    return torch.cat(cos_parts, dim=0), torch.cat(sin_parts, dim=0)
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
|
||||
c = self.out_channels
|
||||
@ -436,25 +434,57 @@ class JoyImageTransformer3DModel(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
ref_latents=None,
|
||||
) -> torch.Tensor:
|
||||
_, _, ot, oh, ow = hidden_states.shape
|
||||
tt = ot // self.patch_size[0]
|
||||
th = oh // self.patch_size[1]
|
||||
tw = ow // self.patch_size[2]
|
||||
# The target noise latent and each reference latent are independently patchified by img_in
|
||||
# (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...].
|
||||
# RoPE is built per component so references may differ in resolution. Only the leading
|
||||
# target segment (tt*th*tw tokens) is projected back out; reference tokens are dropped.
|
||||
# A single reference is simply the len(ref_latents) == 1 case.
|
||||
if hidden_states.ndim != 5:
|
||||
raise ValueError(f"JoyImage transformer expects 5D (B,C,T,H,W) hidden_states; got shape {tuple(hidden_states.shape)}")
|
||||
|
||||
img = self.img_in(hidden_states).flatten(2).transpose(1, 2)
|
||||
_, _, ot, oh, ow = hidden_states.shape
|
||||
pt, ph, pw = self.patch_size
|
||||
if ot % pt != 0 or oh % ph != 0 or ow % pw != 0:
|
||||
raise ValueError(
|
||||
f"JoyImage: target latent spatial/temporal shape {(ot, oh, ow)} must be divisible by patch_size {tuple(self.patch_size)}"
|
||||
)
|
||||
tt = ot // pt
|
||||
th = oh // ph
|
||||
tw = ow // pw
|
||||
|
||||
components = [hidden_states]
|
||||
if ref_latents is not None:
|
||||
for r in ref_latents:
|
||||
if r.ndim != 5:
|
||||
raise ValueError(f"JoyImage: each reference latent must be 5D (B,C,T,H,W); got shape {tuple(r.shape)}")
|
||||
components.append(r)
|
||||
|
||||
component_sizes = []
|
||||
img_tokens = []
|
||||
for comp in components:
|
||||
_, _, ct, ch, cw = comp.shape
|
||||
if ct % pt != 0 or ch % ph != 0 or cw % pw != 0:
|
||||
raise ValueError(
|
||||
f"JoyImage: component shape {(ct, ch, cw)} must be divisible by patch_size {tuple(self.patch_size)}"
|
||||
)
|
||||
component_sizes.append((ct // pt, ch // ph, cw // pw))
|
||||
tokens = self.img_in(comp).flatten(2).transpose(1, 2) # (B, n_i, D)
|
||||
img_tokens.append(tokens)
|
||||
|
||||
img = torch.cat(img_tokens, dim=1)
|
||||
|
||||
_, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
|
||||
if vec.shape[-1] > self.hidden_size:
|
||||
vec = vec.unflatten(1, (6, -1))
|
||||
|
||||
txt_seq_len = txt.shape[1]
|
||||
|
||||
vis_freqs, txt_freqs = self.get_rotary_pos_embed(
|
||||
vis_rope_size=[tt, th, tw],
|
||||
txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,
|
||||
vis_cos, vis_sin = self.get_rotary_pos_embed_for_components(
|
||||
component_sizes,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
vis_freqs = (vis_cos, vis_sin)
|
||||
txt_freqs = None
|
||||
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(
|
||||
@ -465,5 +495,7 @@ class JoyImageTransformer3DModel(nn.Module):
|
||||
)
|
||||
|
||||
img = self.proj_out(self.norm_out(img))
|
||||
target_tokens = tt * th * tw
|
||||
img = img[:, :target_tokens, :]
|
||||
img = self.unpatchify(img, tt, th, tw)
|
||||
return img
|
||||
|
||||
@ -2131,8 +2131,9 @@ class QwenImage(BaseModel):
|
||||
return out
|
||||
|
||||
class JoyImage(BaseModel):
|
||||
# JoyImageEdit: 6D stacking + [last, first, ...] rotation, plus hard-wired guidance rescale,
|
||||
# are deliberately handled HERE (not in the transformer) so the transformer stays 5D-in / 5D-out.
|
||||
# The noise latent and every reference latent are concatenated as a token sequence inside the
|
||||
# transformer. A single-reference edit is just the len(ref_latents) == 1 case. The built-in CFG
|
||||
# guidance rescale is installed from here.
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
@ -2177,8 +2178,9 @@ class JoyImage(BaseModel):
|
||||
if ref_latents is None or len(ref_latents) == 0:
|
||||
raise ValueError(
|
||||
"JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry "
|
||||
"reference_latents. Connect the same image+vae into both TextEncodeJoyImageEdit nodes. "
|
||||
"Empty negative prompts still need image+vae wired."
|
||||
"reference_latents. Wire the same reference image(s) and vae into both the positive and "
|
||||
"negative TextEncodeJoyImageEdit / TextEncodeJoyImageEditPlus nodes. Empty negative "
|
||||
"prompts still need the image(s) and vae."
|
||||
)
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
@ -2194,8 +2196,8 @@ class JoyImage(BaseModel):
|
||||
return out
|
||||
|
||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
# 6D stacking + [last, first, ...] rotation: bring noise (5D x) and the ref_latents (CONDList -> list)
|
||||
# into a single 5D tensor (B, C, n*T, H, W) where slot 0 along T is the noise after rotation.
|
||||
# Pass the noise latent and the reference latents to the transformer, which patchifies each
|
||||
# component and concatenates them along the sequence dim. References may be any resolution.
|
||||
if c_concat is not None:
|
||||
raise ValueError("JoyImage does not support c_concat / noise_concat conditioning")
|
||||
self._ensure_guidance_rescale_installed()
|
||||
@ -2225,38 +2227,26 @@ class JoyImage(BaseModel):
|
||||
if ref_latents is None or len(ref_latents) == 0:
|
||||
raise ValueError("JoyImageEdit forward requires ref_latents; got none.")
|
||||
|
||||
# Build 6D (B, n, C, T, H, W) with refs first then noise, then rotate
|
||||
# [last, first, ...] so the noise moves to the front, and reshape to 5D (B, C, n*T, H, W).
|
||||
b, c, t_noise, h, w = xc.shape
|
||||
ref_5d = []
|
||||
if xc.ndim != 5:
|
||||
raise ValueError("JoyImageEdit: noise latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(xc.shape)))
|
||||
|
||||
refs = []
|
||||
for r in ref_latents:
|
||||
if r.shape[-3:] != xc.shape[-3:]:
|
||||
if r.ndim != 5:
|
||||
raise ValueError(
|
||||
"JoyImageEdit: reference latent spatial/temporal shape {} must match noise {}.".format(
|
||||
tuple(r.shape), tuple(xc.shape)
|
||||
)
|
||||
"JoyImageEdit: each reference latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(r.shape))
|
||||
)
|
||||
ref_5d.append(r.to(device=device, dtype=dtype))
|
||||
stacked = torch.stack([*ref_5d, xc], dim=1) # (B, n, C, T, H, W)
|
||||
n = stacked.shape[1]
|
||||
rotated = torch.cat([stacked[:, -1:], stacked[:, :-1]], dim=1) # noise -> front
|
||||
flat = rotated.permute(0, 2, 1, 3, 4, 5).reshape(b, c, n * t_noise, h, w)
|
||||
refs.append(r.to(device=device, dtype=dtype))
|
||||
|
||||
if control is not None:
|
||||
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
|
||||
|
||||
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states); it does
|
||||
# not accept control/_options/extra_conds. Pass context positionally; the text-encoder
|
||||
# output IS what's threaded into encoder_hidden_states.
|
||||
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states,
|
||||
# ref_latents); it does not accept control/_options/other extra_conds.
|
||||
if extra_conds:
|
||||
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
|
||||
|
||||
model_output = self.diffusion_model(flat, t_in, context)
|
||||
|
||||
# After the rotation noise sat at slot 0; pluck it back out from the n*T axis.
|
||||
c_out = model_output.shape[1]
|
||||
out_6d = model_output.reshape(b, c_out, n, t_noise, h, w)
|
||||
noise_pred = out_6d[:, :, 0] # (B, C, T, H, W)
|
||||
noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs)
|
||||
|
||||
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
|
||||
|
||||
|
||||
@ -13,9 +13,10 @@ import torch.nn.functional as F
|
||||
from comfy import sd1_clip
|
||||
from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer
|
||||
|
||||
# Prompt templates for the text-only and image-conditioned modes. The
|
||||
# image-conditioned template wraps the user text with a single
|
||||
# `<|vision_start|><|image_pad|><|vision_end|>` block; one user turn per call.
|
||||
# Prompt templates for the text-only and image-conditioned modes. The image-conditioned template
|
||||
# wraps the user text with one `<|vision_start|><|image_pad|><|vision_end|>` block per reference
|
||||
# image (no separator between blocks); `{vision}` is filled with the N concatenated blocks and
|
||||
# `{prompt}` with the user text.
|
||||
JOYIMAGE_TEMPLATE_TEXT = (
|
||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
||||
@ -25,9 +26,12 @@ JOYIMAGE_TEMPLATE_TEXT = (
|
||||
JOYIMAGE_TEMPLATE_IMAGE = (
|
||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
||||
"<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
"<|im_start|>user\n{vision}{prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
# A single vision block; N copies are concatenated to condition on N reference images.
|
||||
JOYIMAGE_VISION_BLOCK = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
|
||||
# Number of leading template tokens (system prompt + the user block's opening
|
||||
# `<|im_start|>`) stripped from the encoded output by
|
||||
# JoyImageTEModel.encode_token_weights, so the kept sequence begins at the
|
||||
@ -165,12 +169,14 @@ class JoyImageTokenizer(Qwen3VLTokenizer):
|
||||
"""JoyImageEdit tokenizer.
|
||||
|
||||
``tokenize_with_weights(text, images=[...])`` selects the image-conditioned
|
||||
template when one or more image tensors are passed, otherwise the text-only
|
||||
template. Each ``<|image_pad|>`` token in the formatted prompt is replaced
|
||||
with an embedding marker so `SDClipModel.process_tokens` routes the image
|
||||
through `Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading
|
||||
template tokens are stripped downstream by
|
||||
`JoyImageTEModel.encode_token_weights`. No ``<think>`` block is appended.
|
||||
template when one or more image tensors are passed, emitting one
|
||||
``<|vision_start|><|image_pad|><|vision_end|>`` block per image (N blocks
|
||||
for N reference images), otherwise the text-only template. Each
|
||||
``<|image_pad|>`` token in the formatted prompt is replaced with an
|
||||
embedding marker so `SDClipModel.process_tokens` routes each image through
|
||||
`Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading template
|
||||
tokens are stripped downstream by `JoyImageTEModel.encode_token_weights`.
|
||||
No ``<think>`` block is appended.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@ -188,7 +194,9 @@ class JoyImageTokenizer(Qwen3VLTokenizer):
|
||||
elif llama_template is not None:
|
||||
llama_text = llama_template.format(text)
|
||||
elif len(images) > 0:
|
||||
llama_text = self.llama_template_images.format(text)
|
||||
# One vision block per reference image.
|
||||
vision = JOYIMAGE_VISION_BLOCK * len(images)
|
||||
llama_text = self.llama_template_images.format(vision=vision, prompt=text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
|
||||
|
||||
@ -76,11 +76,80 @@ class TextEncodeJoyImageEdit(io.ComfyNode):
|
||||
return io.NodeOutput(conditioning, resized_image)
|
||||
|
||||
|
||||
class TextEncodeJoyImageEditPlus(io.ComfyNode):
    """JoyImageEdit multi-image (Plus) text-encode node.

    Exposes ``MAX_IMAGES`` optional image inputs (``image1`` .. ``image6``); at
    least one must be connected. Each supplied image is bucket-resized
    independently via ``_find_best_bucket``, VAE-encoded, and appended in
    input order to ``conditioning["reference_latents"]``. Unconnected slots
    are skipped, so the first *connected* image becomes ref0, the next ref1,
    and so on. All resized images are handed to ``clip.tokenize`` in a single
    call.

    Outputs the conditioning and the last resized reference image.
    """

    # Number of optional image inputs. `execute` spells the matching keyword
    # parameters out explicitly, so its signature must be extended if this grows.
    MAX_IMAGES = 6

    @classmethod
    def define_schema(cls):
        """Return the node schema: clip, prompt, vae and the optional image slots."""
        return io.Schema(
            node_id="TextEncodeJoyImageEditPlus",
            category="advanced/conditioning",
            inputs=[
                io.Clip.Input("clip"),
                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
                io.Vae.Input("vae"),
                # image1 .. image{MAX_IMAGES}, all optional.
                *[
                    io.Image.Input(f"image{index}", optional=True)
                    for index in range(1, cls.MAX_IMAGES + 1)
                ],
            ],
            outputs=[
                io.Conditioning.Output(),
                io.Image.Output(display_name="image"),
            ],
        )

    @classmethod
    def execute(cls, clip, prompt, vae, image1=None, image2=None, image3=None,
                image4=None, image5=None, image6=None) -> io.NodeOutput:
        """Encode the prompt together with every connected reference image.

        Raises:
            ValueError: if none of the image inputs is connected.
        """
        candidates = (image1, image2, image3, image4, image5, image6)
        supplied = [img for img in candidates if img is not None]
        if not supplied:
            raise ValueError(
                "TextEncodeJoyImageEditPlus requires at least one reference image."
            )

        resized_images = []
        ref_latents = []
        for image in supplied:
            # Channels-last -> channels-first for common_upscale.
            samples = image.movedim(-1, 1)
            src_h, src_w = samples.shape[2], samples.shape[3]
            bucket_h, bucket_w = _find_best_bucket(src_h, src_w)

            resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center")
            # Back to channels-last, keeping only the first three channels.
            resized_image = resized.movedim(1, -1)[:, :, :, :3]
            resized_images.append(resized_image)
            ref_latents.append(vae.encode(resized_image))

        tokens = clip.tokenize(prompt, images=resized_images)
        conditioning = clip.encode_from_tokens_scheduled(tokens)
        conditioning = node_helpers.conditioning_set_values(
            conditioning, {"reference_latents": ref_latents}, append=True,
        )

        # The last reference sets the target resolution; return it for VAEEncode and the
        # matching negative encode.
        return io.NodeOutput(conditioning, resized_images[-1])
|
||||
|
||||
|
||||
class JoyImageExtension(ComfyExtension):
    """Extension entry that exposes the JoyImage text-encode nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        return [
            TextEncodeJoyImageEdit,
            TextEncodeJoyImageEditPlus,
        ]
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user