mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Merge 5b6dfcbe46 into 96e0e3585b
This commit is contained in:
commit
9fcdcf8d69
549
comfy/ldm/joyimage/model.py
Normal file
549
comfy/ldm/joyimage/model.py
Normal file
@ -0,0 +1,549 @@
|
||||
# https://github.com/jdopensource/JoyAI-Image-Edit (Apache 2.0)
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
class FP32LayerNorm(nn.Module):
|
||||
def __init__(self, normalized_shape, eps: float = 1e-6, dtype=None, device=None):
|
||||
super().__init__()
|
||||
if isinstance(normalized_shape, int):
|
||||
normalized_shape = (normalized_shape,)
|
||||
self.normalized_shape = tuple(normalized_shape)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
out = F.layer_norm(x.float(), self.normalized_shape, None, None, self.eps)
|
||||
return out.to(orig_dtype)
|
||||
|
||||
|
||||
def _apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
ndim = xq.ndim
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(xq.shape)]
|
||||
cos = freqs_cis[0].view(*shape).to(xq.device)
|
||||
sin = freqs_cis[1].view(*shape).to(xq.device)
|
||||
|
||||
def _rotate_half(x):
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
xq_out = (xq.float() * cos + _rotate_half(xq) * sin).type_as(xq)
|
||||
xk_out = (xk.float() * cos + _rotate_half(xk) * sin).type_as(xk)
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
class JoyImageModulate(nn.Module):
|
||||
def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.factor = factor
|
||||
self.modulate_table = nn.Parameter(
|
||||
torch.zeros(1, factor, hidden_size, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> list:
|
||||
if x.ndim != 3:
|
||||
x = x.unsqueeze(1)
|
||||
table = self.modulate_table.to(dtype=x.dtype, device=x.device)
|
||||
return [o.squeeze(1) for o in (table + x).chunk(self.factor, dim=1)]
|
||||
|
||||
|
||||
class JoyImageFeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
inner_dim: int,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.net = nn.ModuleList([
|
||||
_GeluApproximate(dim, inner_dim, dtype=dtype, device=device, operations=operations),
|
||||
nn.Dropout(0.0),
|
||||
operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device),
|
||||
])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
x = module(x)
|
||||
return x
|
||||
|
||||
|
||||
class _GeluApproximate(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(dim_in, dim_out, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.gelu(self.proj(x), approximate="tanh")
|
||||
|
||||
|
||||
class JoyImageAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.img_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device)
|
||||
self.img_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
||||
self.img_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
||||
self.img_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.txt_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device)
|
||||
self.txt_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
||||
self.txt_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
||||
self.txt_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
heads = self.num_attention_heads
|
||||
|
||||
img_q, img_k, img_v = self.img_attn_qkv(img).chunk(3, dim=-1)
|
||||
txt_q, txt_k, txt_v = self.txt_attn_qkv(txt).chunk(3, dim=-1)
|
||||
|
||||
img_q = img_q.unflatten(-1, (heads, -1))
|
||||
img_k = img_k.unflatten(-1, (heads, -1))
|
||||
img_v = img_v.unflatten(-1, (heads, -1))
|
||||
txt_q = txt_q.unflatten(-1, (heads, -1))
|
||||
txt_k = txt_k.unflatten(-1, (heads, -1))
|
||||
txt_v = txt_v.unflatten(-1, (heads, -1))
|
||||
|
||||
img_q = self.img_attn_q_norm(img_q)
|
||||
img_k = self.img_attn_k_norm(img_k)
|
||||
txt_q = self.txt_attn_q_norm(txt_q)
|
||||
txt_k = self.txt_attn_k_norm(txt_k)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
vis_freqs, txt_freqs = image_rotary_emb
|
||||
if vis_freqs is not None:
|
||||
img_q, img_k = _apply_rotary_emb(img_q, img_k, vis_freqs)
|
||||
if txt_freqs is not None:
|
||||
txt_q, txt_k = _apply_rotary_emb(txt_q, txt_k, txt_freqs)
|
||||
|
||||
joint_q = torch.cat([img_q, txt_q], dim=1)
|
||||
joint_k = torch.cat([img_k, txt_k], dim=1)
|
||||
joint_v = torch.cat([img_v, txt_v], dim=1)
|
||||
|
||||
joint_q = joint_q.flatten(2, 3)
|
||||
joint_k = joint_k.flatten(2, 3)
|
||||
joint_v = joint_v.flatten(2, 3)
|
||||
|
||||
joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads, transformer_options=transformer_options)
|
||||
joint_out = joint_out.to(joint_q.dtype)
|
||||
|
||||
seq_img = img.shape[1]
|
||||
img_out = joint_out[:, :seq_img, :]
|
||||
txt_out = joint_out[:, seq_img:, :]
|
||||
|
||||
img_out = self.img_attn_proj(img_out)
|
||||
txt_out = self.txt_attn_proj(txt_out)
|
||||
return img_out, txt_out
|
||||
|
||||
|
||||
class JoyImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
eps: float = 1e-6,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
mlp_hidden_dim = int(dim * mlp_width_ratio)
|
||||
|
||||
self.img_mod = JoyImageModulate(dim, factor=6, dtype=dtype, device=device, operations=operations)
|
||||
self.img_norm1 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
|
||||
self.img_norm2 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
|
||||
self.img_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_mod = JoyImageModulate(dim, factor=6, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_norm1 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
|
||||
self.txt_norm2 = FP32LayerNorm(dim, eps=eps, dtype=dtype, device=device)
|
||||
self.txt_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.attn = JoyImageAttention(
|
||||
dim=dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
eps=eps,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
(
|
||||
img_mod1_shift,
|
||||
img_mod1_scale,
|
||||
img_mod1_gate,
|
||||
img_mod2_shift,
|
||||
img_mod2_scale,
|
||||
img_mod2_gate,
|
||||
) = self.img_mod(temb)
|
||||
(
|
||||
txt_mod1_shift,
|
||||
txt_mod1_scale,
|
||||
txt_mod1_gate,
|
||||
txt_mod2_shift,
|
||||
txt_mod2_scale,
|
||||
txt_mod2_gate,
|
||||
) = self.txt_mod(temb)
|
||||
|
||||
img_normed = self.img_norm1(hidden_states)
|
||||
txt_normed = self.txt_norm1(encoder_hidden_states)
|
||||
img_modulated = img_normed * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
|
||||
txt_modulated = txt_normed * (1 + txt_mod1_scale.unsqueeze(1)) + txt_mod1_shift.unsqueeze(1)
|
||||
|
||||
img_attn, txt_attn = self.attn(img_modulated, txt_modulated, image_rotary_emb, transformer_options=transformer_options)
|
||||
|
||||
hidden_states = hidden_states + img_attn * img_mod1_gate.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_attn * txt_mod1_gate.unsqueeze(1)
|
||||
|
||||
img_ffn_normed = self.img_norm2(hidden_states)
|
||||
txt_ffn_normed = self.txt_norm2(encoder_hidden_states)
|
||||
img_ffn_input = img_ffn_normed * (1 + img_mod2_scale.unsqueeze(1)) + img_mod2_shift.unsqueeze(1)
|
||||
txt_ffn_input = txt_ffn_normed * (1 + txt_mod2_scale.unsqueeze(1)) + txt_mod2_shift.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.img_mlp(img_ffn_input) * img_mod2_gate.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + self.txt_mlp(txt_ffn_input) * txt_mod2_gate.unsqueeze(1)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class JoyImageTimeTextImageEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
time_freq_dim: int,
|
||||
time_proj_dim: int,
|
||||
text_embed_dim: int,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.time_embedder = TimestepEmbedding(
|
||||
in_channels=time_freq_dim,
|
||||
time_embed_dim=dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.act_fn = nn.SiLU()
|
||||
self.time_proj = operations.Linear(dim, time_proj_dim, bias=True, dtype=dtype, device=device)
|
||||
self.text_embedder = _PixArtAlphaTextProjection(
|
||||
text_embed_dim, dim, dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
temb = self.time_embedder(timestep.to(dtype=encoder_hidden_states.dtype)).type_as(encoder_hidden_states)
|
||||
timestep_proj = self.time_proj(self.act_fn(temb))
|
||||
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
|
||||
return temb, timestep_proj, encoder_hidden_states
|
||||
|
||||
|
||||
class _PixArtAlphaTextProjection(nn.Module):
|
||||
def __init__(self, in_features: int, hidden_size: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_features, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
self.linear_2 = operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, caption: torch.Tensor) -> torch.Tensor:
|
||||
return self.linear_2(self.act_1(self.linear_1(caption)))
|
||||
|
||||
|
||||
class JoyImageTransformer3DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: list = [1, 2, 2],
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = None,
|
||||
hidden_size: int = 3072,
|
||||
num_attention_heads: int = 24,
|
||||
text_dim: int = 4096,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
num_layers: int = 20,
|
||||
rope_dim_list: list = [16, 56, 56],
|
||||
rope_type: str = "rope",
|
||||
theta: int = 256,
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.patch_size = list(patch_size)
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.rope_dim_list = list(rope_dim_list)
|
||||
self.rope_type = rope_type
|
||||
self.theta = theta
|
||||
|
||||
if hidden_size % num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
|
||||
)
|
||||
attention_head_dim = hidden_size // num_attention_heads
|
||||
if sum(self.rope_dim_list) != attention_head_dim:
|
||||
raise ValueError(
|
||||
f"sum(rope_dim_list) ({sum(self.rope_dim_list)}) must equal head_dim ({attention_head_dim})"
|
||||
)
|
||||
|
||||
self.img_in = operations.Conv3d(
|
||||
in_channels,
|
||||
hidden_size,
|
||||
kernel_size=tuple(self.patch_size),
|
||||
stride=tuple(self.patch_size),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.condition_embedder = JoyImageTimeTextImageEmbedding(
|
||||
dim=hidden_size,
|
||||
time_freq_dim=256,
|
||||
time_proj_dim=hidden_size * 6,
|
||||
text_embed_dim=text_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.double_blocks = nn.ModuleList([
|
||||
JoyImageTransformerBlock(
|
||||
dim=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = FP32LayerNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj_out = operations.Linear(
|
||||
hidden_size,
|
||||
self.out_channels * math.prod(self.patch_size),
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
def _get_rotary_pos_embed_for_range(
|
||||
self,
|
||||
start: Tuple[int, int, int],
|
||||
stop: Tuple[int, int, int],
|
||||
device=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 3D RoPE for the patch grid range [start, stop) over (t, h, w). Token order after
|
||||
# reshape(-1) is (t, h, w), matching the img_in Conv3d flatten.
|
||||
head_dim = self.hidden_size // self.num_attention_heads
|
||||
rope_dim_list = self.rope_dim_list
|
||||
if rope_dim_list is None:
|
||||
rope_dim_list = [head_dim // 3 for _ in range(3)]
|
||||
if sum(rope_dim_list) != head_dim:
|
||||
raise ValueError("sum(rope_dim_list) should equal head_dim")
|
||||
|
||||
grids = [torch.arange(start[i], stop[i], dtype=torch.float32, device=device) for i in range(3)]
|
||||
mesh = torch.stack(torch.meshgrid(*grids, indexing="ij"), dim=0)
|
||||
|
||||
cos_parts, sin_parts = [], []
|
||||
for i, dim in enumerate(rope_dim_list):
|
||||
pos = mesh[i].reshape(-1)
|
||||
freqs = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim))
|
||||
angles = torch.outer(pos, freqs)
|
||||
cos_parts.append(angles.cos().repeat_interleave(2, dim=1))
|
||||
sin_parts.append(angles.sin().repeat_interleave(2, dim=1))
|
||||
|
||||
return torch.cat(cos_parts, dim=1), torch.cat(sin_parts, dim=1)
|
||||
|
||||
def get_rotary_pos_embed_for_components(
|
||||
self,
|
||||
component_sizes,
|
||||
device=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# Per-component 3D RoPE. component_sizes is a list of (t, h, w) patch grid sizes in
|
||||
# sequence order [target, ref0, ref1, ...]; h/w restart at 0 for each component while t
|
||||
# continues from the running offset, giving every image its own temporal position band.
|
||||
cos_parts, sin_parts = [], []
|
||||
t_offset = 0
|
||||
for (t, h, w) in component_sizes:
|
||||
cos_emb, sin_emb = self._get_rotary_pos_embed_for_range(
|
||||
start=(t_offset, 0, 0),
|
||||
stop=(t_offset + t, h, w),
|
||||
device=device,
|
||||
)
|
||||
cos_parts.append(cos_emb)
|
||||
sin_parts.append(sin_emb)
|
||||
t_offset += t
|
||||
return torch.cat(cos_parts, dim=0), torch.cat(sin_parts, dim=0)
|
||||
|
||||
def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
|
||||
c = self.out_channels
|
||||
pt, ph, pw = self.patch_size
|
||||
if t * h * w != x.shape[1]:
|
||||
raise ValueError(f"Expected t*h*w ({t * h * w}) to equal x.shape[1] ({x.shape[1]})")
|
||||
x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c)
|
||||
x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
ref_latents=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(hidden_states, timestep, encoder_hidden_states, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
ref_latents=None,
|
||||
transformer_options={},
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# The target noise latent and each reference latent are independently patchified by img_in
|
||||
# (Conv3d) and concatenated along the sequence dim, in the order [target, ref0, ref1, ...].
|
||||
# RoPE is built per component so references may differ in resolution. Only the leading
|
||||
# target segment (tt*th*tw tokens) is projected back out; reference tokens are dropped.
|
||||
# A single reference is simply the len(ref_latents) == 1 case.
|
||||
if hidden_states.ndim != 5:
|
||||
raise ValueError(f"JoyImage transformer expects 5D (B,C,T,H,W) hidden_states; got shape {tuple(hidden_states.shape)}")
|
||||
|
||||
_, _, ot, oh, ow = hidden_states.shape
|
||||
pt, ph, pw = self.patch_size
|
||||
if ot % pt != 0 or oh % ph != 0 or ow % pw != 0:
|
||||
raise ValueError(
|
||||
f"JoyImage: target latent spatial/temporal shape {(ot, oh, ow)} must be divisible by patch_size {tuple(self.patch_size)}"
|
||||
)
|
||||
tt = ot // pt
|
||||
th = oh // ph
|
||||
tw = ow // pw
|
||||
|
||||
components = [hidden_states]
|
||||
if ref_latents is not None:
|
||||
for r in ref_latents:
|
||||
if r.ndim != 5:
|
||||
raise ValueError(f"JoyImage: each reference latent must be 5D (B,C,T,H,W); got shape {tuple(r.shape)}")
|
||||
components.append(r)
|
||||
|
||||
component_sizes = []
|
||||
img_tokens = []
|
||||
for comp in components:
|
||||
_, _, ct, ch, cw = comp.shape
|
||||
if ct % pt != 0 or ch % ph != 0 or cw % pw != 0:
|
||||
raise ValueError(
|
||||
f"JoyImage: component shape {(ct, ch, cw)} must be divisible by patch_size {tuple(self.patch_size)}"
|
||||
)
|
||||
component_sizes.append((ct // pt, ch // ph, cw // pw))
|
||||
tokens = self.img_in(comp).flatten(2).transpose(1, 2) # (B, n_i, D)
|
||||
img_tokens.append(tokens)
|
||||
|
||||
img = torch.cat(img_tokens, dim=1)
|
||||
|
||||
_, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
|
||||
if vec.shape[-1] > self.hidden_size:
|
||||
vec = vec.unflatten(1, (6, -1))
|
||||
|
||||
vis_cos, vis_sin = self.get_rotary_pos_embed_for_components(
|
||||
component_sizes,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
vis_freqs = (vis_cos, vis_sin)
|
||||
txt_freqs = None
|
||||
image_rotary_emb = (vis_freqs, txt_freqs)
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(
|
||||
hidden_states=args["img"],
|
||||
encoder_hidden_states=args["txt"],
|
||||
temb=args["vec"],
|
||||
image_rotary_emb=args["pe"],
|
||||
transformer_options=args.get("transformer_options"),
|
||||
)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
"txt": txt,
|
||||
"vec": vec,
|
||||
"pe": image_rotary_emb,
|
||||
"transformer_options": transformer_options},
|
||||
{"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(
|
||||
hidden_states=img,
|
||||
encoder_hidden_states=txt,
|
||||
temb=vec,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
img = self.proj_out(self.norm_out(img))
|
||||
target_tokens = tt * th * tw
|
||||
img = img[:, :target_tokens, :]
|
||||
img = self.unpatchify(img, tt, th, tw)
|
||||
return img
|
||||
@ -57,6 +57,7 @@ import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.boogu.model
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.joyimage.model
|
||||
import comfy.ldm.ideogram4.model
|
||||
import comfy.ldm.krea2.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
@ -2264,6 +2265,126 @@ class QwenImage(BaseModel):
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
class JoyImage(BaseModel):
|
||||
# The noise latent and every reference latent are concatenated as a token sequence inside the
|
||||
# transformer. A single-reference edit is just the len(ref_latents) == 1 case. The built-in CFG
|
||||
# guidance rescale is installed from here.
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.joyimage.model.JoyImageTransformer3DModel)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
@staticmethod
|
||||
def _guidance_rescale_cfg(args):
|
||||
# CFG combine + per-row L2 rescale in eps-space (guidance rescale).
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
comb = uncond + cond_scale * (cond - uncond)
|
||||
cond_norm = torch.norm(cond, dim=1, keepdim=True)
|
||||
comb_norm = torch.norm(comb, dim=1, keepdim=True)
|
||||
return comb * (cond_norm / comb_norm.clamp_min(1e-6))
|
||||
|
||||
def _ensure_guidance_rescale_installed(self):
|
||||
# Self-install the hard-wired guidance rescale once the patcher binds (sd.py doesn't expose a hook
|
||||
# for this; doing it here keeps the edit confined to model_base.py). Idempotent; refuses to install
|
||||
# if a different sampler_cfg_function is already present (e.g. a CFGNorm node) so the user's
|
||||
# override does not silently shadow JoyImage's required rescale.
|
||||
patcher = self.current_patcher
|
||||
if patcher is None:
|
||||
return
|
||||
existing = patcher.model_options.get("sampler_cfg_function", None)
|
||||
if existing is JoyImage._guidance_rescale_cfg:
|
||||
return
|
||||
if existing is not None:
|
||||
raise RuntimeError(
|
||||
"JoyImage requires its built-in CFG guidance-rescale function "
|
||||
"(comb * cond_norm / comb_norm); an external sampler_cfg_function "
|
||||
"(e.g. CFGNorm) is already installed and would override it. "
|
||||
"Remove the external function before sampling JoyImage."
|
||||
)
|
||||
patcher.set_model_sampler_cfg_function(JoyImage._guidance_rescale_cfg)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is None or len(ref_latents) == 0:
|
||||
raise ValueError(
|
||||
"JoyImageEdit is an edit model: every conditioning (positive AND negative) must carry "
|
||||
"reference_latents. Wire the same reference image(s) and vae into both the positive and "
|
||||
"negative TextEncodeJoyImageEdit / TextEncodeJoyImageEditPlus nodes. Empty negative "
|
||||
"prompts still need the image(s) and vae."
|
||||
)
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
# Pass the noise latent and the reference latents to the transformer, which patchifies each
|
||||
# component and concatenates them along the sequence dim. References may be any resolution.
|
||||
if c_concat is not None:
|
||||
raise ValueError("JoyImage does not support c_concat / noise_concat conditioning")
|
||||
self._ensure_guidance_rescale_installed()
|
||||
sigma = t
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype_inference()
|
||||
xc = xc.to(dtype)
|
||||
device = xc.device
|
||||
t_in = self.model_sampling.timestep(t).float()
|
||||
if context is not None:
|
||||
context = comfy.model_management.cast_to_device(context, device, dtype)
|
||||
|
||||
extra_conds = {}
|
||||
for o in kwargs:
|
||||
extra = kwargs[o]
|
||||
if hasattr(extra, "dtype"):
|
||||
extra = convert_tensor(extra, dtype, device)
|
||||
elif isinstance(extra, list):
|
||||
ex = []
|
||||
for ext in extra:
|
||||
ex.append(convert_tensor(ext, dtype, device))
|
||||
extra = ex
|
||||
extra_conds[o] = extra
|
||||
|
||||
ref_latents = extra_conds.pop("ref_latents", None)
|
||||
if ref_latents is None or len(ref_latents) == 0:
|
||||
raise ValueError("JoyImageEdit forward requires ref_latents; got none.")
|
||||
|
||||
if xc.ndim != 5:
|
||||
raise ValueError("JoyImageEdit: noise latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(xc.shape)))
|
||||
|
||||
refs = []
|
||||
for r in ref_latents:
|
||||
if r.ndim != 5:
|
||||
raise ValueError(
|
||||
"JoyImageEdit: each reference latent must be 5D (B,C,T,H,W); got shape {}.".format(tuple(r.shape))
|
||||
)
|
||||
refs.append(r.to(device=device, dtype=dtype))
|
||||
|
||||
if control is not None:
|
||||
raise ValueError("JoyImageEdit: control (ControlNet) is not supported by the transformer.")
|
||||
|
||||
# The transformer's forward signature is (hidden_states, timestep, encoder_hidden_states,
|
||||
# ref_latents, transformer_options); it does not accept control/other extra_conds.
|
||||
if extra_conds:
|
||||
raise ValueError("JoyImageEdit: unexpected extra_conds keys {} reached the transformer.".format(list(extra_conds.keys())))
|
||||
|
||||
noise_pred = self.diffusion_model(xc, t_in, context, ref_latents=refs, transformer_options=transformer_options)
|
||||
|
||||
return self.model_sampling.calculate_denoised(sigma, noise_pred.float(), x)
|
||||
|
||||
class Ideogram4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel)
|
||||
|
||||
@ -827,6 +827,27 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["default_ref_method"] = "negative_index"
|
||||
return dit_config
|
||||
|
||||
# JoyImageEdit: dual-stream double_blocks with img_attn_qkv, a condition_embedder
|
||||
# time_embedder, and a 5D Conv3d img_in (kernel [1,2,2]).
|
||||
if (
|
||||
'{}double_blocks.0.attn.img_attn_qkv.weight'.format(key_prefix) in state_dict_keys
|
||||
and '{}condition_embedder.time_embedder.linear_1.weight'.format(key_prefix) in state_dict_keys
|
||||
and '{}img_in.weight'.format(key_prefix) in state_dict_keys
|
||||
and len(state_dict['{}img_in.weight'.format(key_prefix)].shape) == 5
|
||||
):
|
||||
img_in = state_dict['{}img_in.weight'.format(key_prefix)]
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "joyimage"
|
||||
dit_config["in_channels"] = img_in.shape[1]
|
||||
dit_config["hidden_size"] = img_in.shape[0]
|
||||
dit_config["patch_size"] = list(img_in.shape[2:])
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
head_dim = state_dict['{}double_blocks.0.attn.img_attn_q_norm.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["num_attention_heads"] = dit_config["hidden_size"] // head_dim
|
||||
# text_dim from the text-embedder input projection
|
||||
dit_config["text_dim"] = state_dict['{}condition_embedder.text_embedder.linear_1.weight'.format(key_prefix)].shape[1]
|
||||
return dit_config
|
||||
|
||||
if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ideogram4"
|
||||
|
||||
11
comfy/sd.py
11
comfy/sd.py
@ -75,6 +75,7 @@ import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
import comfy.text_encoders.sa3
|
||||
import comfy.text_encoders.gpt_oss
|
||||
import comfy.text_encoders.joyimage
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -1305,6 +1306,7 @@ class CLIPType(Enum):
|
||||
IDEOGRAM4 = 30
|
||||
BOOGU = 31
|
||||
KREA2 = 32
|
||||
JOYIMAGE = 33
|
||||
|
||||
|
||||
|
||||
@ -1360,6 +1362,7 @@ class TEModel(Enum):
|
||||
GPT_OSS_20B = 33
|
||||
QWEN3VL_4B = 34
|
||||
QWEN3VL_8B = 35
|
||||
QWEN3VL_8B_JOYIMAGE = 36
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1421,6 +1424,8 @@ def detect_te_model(sd):
|
||||
if weight.shape[0] == 5120:
|
||||
return TEModel.QWEN35_27B
|
||||
return TEModel.QWEN35_2B
|
||||
if "model.language_model.layers.0.self_attn.q_norm.weight" in sd and "model.visual.patch_embed.proj.weight" in sd:
|
||||
return TEModel.QWEN3VL_8B_JOYIMAGE
|
||||
if "model.visual.deepstack_merger_list.0.norm.weight" in sd: # DeepStack is unique to Qwen3-VL
|
||||
return TEModel.QWEN3VL_4B if sd["model.visual.merger.linear_fc2.weight"].shape[0] == 2560 else TEModel.QWEN3VL_8B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
@ -1643,6 +1648,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type)
|
||||
elif te_model == TEModel.QWEN3VL_8B_JOYIMAGE:
|
||||
# Remap the HF Qwen3VLForConditionalGeneration layout to the Qwen3VL
|
||||
# namespace (model.*, visual.*, model.lm_head.*).
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
clip_target.clip = comfy.text_encoders.joyimage.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.joyimage.JoyImageTokenizer
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||
|
||||
@ -1877,6 +1877,45 @@ class QwenImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
|
||||
|
||||
class JoyImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "joyimage",
|
||||
}
|
||||
|
||||
# multiplier=1000: the transformer's time embedding is trained on t in [0,1000].
|
||||
# ModelSamplingDiscreteFlow.timestep(sigma)=sigma*multiplier yields that range; the
|
||||
# multiplier cancels in the sigma table, so it only rescales the timestep value.
|
||||
sampling_settings = {
|
||||
"multiplier": 1000,
|
||||
"shift": 1.5,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.8
|
||||
|
||||
unet_extra_config = {
|
||||
"theta": 10000,
|
||||
"rope_dim_list": [16, 56, 56],
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Wan21 # AutoencoderKLWan: z_dim=16, scale_factor_spatial=8, scale_factor_temporal=4.
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.JoyImage(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
# Imported lazily so this module stays importable without the text-encoder deps loaded;
|
||||
# the import is only resolved when a checkpoint is actually loaded.
|
||||
import comfy.text_encoders.joyimage
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
qwen3vl_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.joyimage.JoyImageTokenizer, comfy.text_encoders.joyimage.te(**qwen3vl_detect))
|
||||
|
||||
class HunyuanImage21(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
@ -2354,6 +2393,7 @@ models = [
|
||||
Omnigen2,
|
||||
Boogu,
|
||||
QwenImage,
|
||||
JoyImage,
|
||||
Ideogram4,
|
||||
Krea2,
|
||||
Flux2,
|
||||
|
||||
280
comfy/text_encoders/joyimage.py
Normal file
280
comfy/text_encoders/joyimage.py
Normal file
@ -0,0 +1,280 @@
|
||||
"""JoyImageEdit text encoder: a stock Qwen3-VL-8B multimodal stack feeding the
|
||||
JoyImageEdit DiT, built on `comfy.text_encoders.qwen3vl` with the
|
||||
JoyImage-specific prompt templates, system-prompt strip, image preprocessing,
|
||||
and conditioning-path multimodal handling.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy import sd1_clip
|
||||
from comfy.text_encoders.qwen3vl import Qwen3VL, Qwen3VLTokenizer
|
||||
|
||||
# Prompt templates for the text-only and image-conditioned modes. The image-conditioned template
|
||||
# wraps the user text with one `<|vision_start|><|image_pad|><|vision_end|>` block per reference
|
||||
# image (no separator between blocks); `{vision}` is filled with the N concatenated blocks and
|
||||
# `{prompt}` with the user text.
|
||||
JOYIMAGE_TEMPLATE_TEXT = (
|
||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
||||
"<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
JOYIMAGE_TEMPLATE_IMAGE = (
|
||||
"<|im_start|>system\n \\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|im_end|>\n"
|
||||
"<|im_start|>user\n{vision}{prompt}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
# A single vision block; N copies are concatenated to condition on N reference images.
|
||||
JOYIMAGE_VISION_BLOCK = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
|
||||
# Number of leading template tokens (system prompt + the user block's opening
|
||||
# `<|im_start|>`) stripped from the encoded output by
|
||||
# JoyImageTEModel.encode_token_weights, so the kept sequence begins at the
|
||||
# `user` token.
|
||||
JOYIMAGE_DROP_IDX = 34
|
||||
|
||||
# Special-token ids (vocab shared with Qwen2.5 / Qwen3, vocab_size 151936).
|
||||
IMAGE_PAD_TOKEN = 151655
|
||||
PAD_TOKEN = 151643
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image preprocessing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def process_qwen3vl_image(
|
||||
image: torch.Tensor,
|
||||
min_pixels: int = 65536,
|
||||
max_pixels: int = 16777216,
|
||||
patch_size: int = 16,
|
||||
temporal_patch_size: int = 2,
|
||||
merge_size: int = 2,
|
||||
image_mean: Optional[List[float]] = None,
|
||||
image_std: Optional[List[float]] = None,
|
||||
):
|
||||
"""Resize, normalize and patch-flatten a single (B=1, H, W, C) image tensor in [0, 1].
|
||||
|
||||
Returns ``(flatten_patches, grid_thw)`` ready for the Qwen3-VL vision tower.
|
||||
Uses bicubic interpolation followed by ``clamp(0, 1)``.
|
||||
"""
|
||||
if image_mean is None:
|
||||
image_mean = [0.5, 0.5, 0.5]
|
||||
if image_std is None:
|
||||
image_std = [0.5, 0.5, 0.5]
|
||||
|
||||
if image.dim() == 3:
|
||||
image = image.unsqueeze(0)
|
||||
batch, height, width, channels = image.shape
|
||||
if batch != 1:
|
||||
raise ValueError("process_qwen3vl_image expects one image (B=1) at a time.")
|
||||
device = image.device
|
||||
|
||||
image = image.permute(0, 3, 1, 2) # (1, C, H, W)
|
||||
img = image[0]
|
||||
|
||||
factor = patch_size * merge_size
|
||||
h_bar = round(height / factor) * factor
|
||||
w_bar = round(width / factor) * factor
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = max(factor, math.floor(height / beta / factor) * factor)
|
||||
w_bar = max(factor, math.floor(width / beta / factor) * factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = math.ceil(height * beta / factor) * factor
|
||||
w_bar = math.ceil(width * beta / factor) * factor
|
||||
|
||||
img_resized = F.interpolate(
|
||||
img.unsqueeze(0), size=(h_bar, w_bar), mode="bicubic", align_corners=False,
|
||||
).squeeze(0).clamp(0.0, 1.0)
|
||||
|
||||
normalized = img_resized.clone()
|
||||
for c in range(3):
|
||||
normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
|
||||
|
||||
grid_h = h_bar // patch_size
|
||||
grid_w = w_bar // patch_size
|
||||
grid_thw = torch.tensor([[1, grid_h, grid_w]], device=device, dtype=torch.long)
|
||||
|
||||
# Single-frame inputs are duplicated along T to fill the 2-frame temporal
|
||||
# patch kernel; matches Qwen2VLImageProcessorFast for static images.
|
||||
pixel_values = normalized.unsqueeze(0).repeat(temporal_patch_size, 1, 1, 1)
|
||||
grid_t = 1
|
||||
channel = pixel_values.shape[1]
|
||||
patches = pixel_values.reshape(
|
||||
grid_t, temporal_patch_size, channel,
|
||||
grid_h // merge_size, merge_size, patch_size,
|
||||
grid_w // merge_size, merge_size, patch_size,
|
||||
)
|
||||
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
|
||||
flatten_patches = patches.reshape(
|
||||
grid_t * grid_h * grid_w,
|
||||
channel * temporal_patch_size * patch_size * patch_size,
|
||||
)
|
||||
return flatten_patches, grid_thw
|
||||
|
||||
|
||||
class Qwen3VL8B_JoyImage(Qwen3VL):
|
||||
"""JoyImage Qwen3-VL-8B encoder.
|
||||
|
||||
Stock `qwen3vl_8b` config (text dims 4096 / 36L / 32H / 8 kv; interleaved
|
||||
3D MRoPE rope_dims=[24,20,20], rope_theta=5e6; vision 1152/4304, depth 27,
|
||||
patch_size 16, deepstack_visual_indexes=[8,16,24]).
|
||||
"""
|
||||
|
||||
model_type = "qwen3vl_8b"
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
# Run the vision tower with JoyImage's bicubic+clamp preprocessing and
|
||||
# return ``(merged, {"grid", "deepstack"})``.
|
||||
if embed["type"] == "image":
|
||||
image, grid = process_qwen3vl_image(
|
||||
embed["data"], patch_size=16, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5],
|
||||
)
|
||||
merged, deepstack = self.visual(image.to(device, dtype=torch.float32), grid)
|
||||
return merged, {"grid": grid, "deepstack": deepstack}
|
||||
return None, None
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None,
|
||||
intermediate_output=None, final_layer_norm_intermediate=True,
|
||||
dtype=None, embeds_info=()):
|
||||
# The conditioning path must build the 3D MRoPE position ids for the
|
||||
# image-token block and inject the deepstack visual features.
|
||||
# `build_image_inputs` returns the kwargs the decoder expects:
|
||||
# (position_ids, visual_pos_masks, deepstack).
|
||||
if embeds is not None:
|
||||
position_ids, visual_pos_masks, deepstack = self.build_image_inputs(embeds, embeds_info)
|
||||
else:
|
||||
position_ids, visual_pos_masks, deepstack = None, None, None
|
||||
return self.model(
|
||||
x,
|
||||
attention_mask=attention_mask,
|
||||
embeds=embeds,
|
||||
num_tokens=num_tokens,
|
||||
intermediate_output=intermediate_output,
|
||||
final_layer_norm_intermediate=final_layer_norm_intermediate,
|
||||
dtype=dtype,
|
||||
position_ids=position_ids,
|
||||
deepstack_embeds=deepstack,
|
||||
visual_pos_masks=visual_pos_masks,
|
||||
)
|
||||
|
||||
|
||||
class JoyImageTokenizer(Qwen3VLTokenizer):
|
||||
"""JoyImageEdit tokenizer.
|
||||
|
||||
``tokenize_with_weights(text, images=[...])`` selects the image-conditioned
|
||||
template when one or more image tensors are passed, emitting one
|
||||
``<|vision_start|><|image_pad|><|vision_end|>`` block per image (N blocks
|
||||
for N reference images), otherwise the text-only template. Each
|
||||
``<|image_pad|>`` token in the formatted prompt is replaced with an
|
||||
embedding marker so `SDClipModel.process_tokens` routes each image through
|
||||
`Qwen3VL8B_JoyImage.preprocess_embed`; ``drop_idx=34`` leading template
|
||||
tokens are stripped downstream by `JoyImageTEModel.encode_token_weights`.
|
||||
No ``<think>`` block is appended.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(
|
||||
embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
model_type="qwen3vl_8b",
|
||||
)
|
||||
self.llama_template = JOYIMAGE_TEMPLATE_TEXT
|
||||
self.llama_template_images = JOYIMAGE_TEMPLATE_IMAGE
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,
|
||||
images=[], **kwargs):
|
||||
if text.startswith("<|im_start|>"):
|
||||
llama_text = text
|
||||
elif llama_template is not None:
|
||||
llama_text = llama_template.format(text)
|
||||
elif len(images) > 0:
|
||||
# One vision block per reference image.
|
||||
vision = JOYIMAGE_VISION_BLOCK * len(images)
|
||||
llama_text = self.llama_template_images.format(vision=vision, prompt=text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
|
||||
# Tokenize the already-rendered template via the grandparent
|
||||
# (SD1Tokenizer); calling `super()` would re-apply the Qwen3VL template.
|
||||
tokens = sd1_clip.SD1Tokenizer.tokenize_with_weights(
|
||||
self, llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs,
|
||||
)
|
||||
|
||||
key_name = next(iter(tokens))
|
||||
embed_count = 0
|
||||
qwen_tokens = tokens[key_name]
|
||||
for r in qwen_tokens:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == IMAGE_PAD_TOKEN:
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count],
|
||||
"original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
if embed_count != len(images):
|
||||
raise ValueError(
|
||||
f"JoyImageTokenizer: prompt had {embed_count} <|image_pad|> placeholders "
|
||||
f"but {len(images)} image(s) were supplied. Either pre-format the prompt "
|
||||
f"with `<|vision_start|><|image_pad|><|vision_end|>` per image or pass an "
|
||||
f"image-free prompt."
|
||||
)
|
||||
return tokens
|
||||
|
||||
|
||||
class _JoyImageClipModel(sd1_clip.SDClipModel):
|
||||
"""Qwen3-VL multimodal encoder wrapper.
|
||||
|
||||
Conditions on the **pre-final-norm** output of the last decoder layer
|
||||
(``layer="hidden", layer_idx=-1, layer_norm_hidden_state=False``). The
|
||||
post-norm ``last_hidden_state`` differs by ~10x in scale and produces broken
|
||||
DiT outputs, so these flags must not be changed.
|
||||
"""
|
||||
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None,
|
||||
attention_mask=True, model_options={}):
|
||||
super().__init__(
|
||||
device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
|
||||
dtype=dtype, special_tokens={"pad": PAD_TOKEN}, layer_norm_hidden_state=False,
|
||||
model_class=Qwen3VL8B_JoyImage, enable_attention_masks=attention_mask,
|
||||
return_attention_masks=attention_mask, model_options=model_options,
|
||||
)
|
||||
|
||||
|
||||
class JoyImageTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(
|
||||
device=device, dtype=dtype, name="qwen3vl_8b",
|
||||
clip_model=_JoyImageClipModel, model_options=model_options,
|
||||
)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||
# Strip the JOYIMAGE_DROP_IDX-token system-prompt prefix from both the
|
||||
# embedding sequence and the attention mask.
|
||||
if out.shape[1] <= JOYIMAGE_DROP_IDX:
|
||||
raise ValueError(
|
||||
f"JoyImageTEModel: encoded sequence length {out.shape[1]} is shorter "
|
||||
f"than drop_idx={JOYIMAGE_DROP_IDX}; the prompt did not include the "
|
||||
f"template prefix."
|
||||
)
|
||||
out = out[:, JOYIMAGE_DROP_IDX:]
|
||||
if "attention_mask" in extra:
|
||||
extra["attention_mask"] = extra["attention_mask"][:, JOYIMAGE_DROP_IDX:]
|
||||
return out, pooled, extra
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class JoyImageTEModel_(JoyImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return JoyImageTEModel_
|
||||
157
comfy_extras/nodes_joyimage.py
Normal file
157
comfy_extras/nodes_joyimage.py
Normal file
@ -0,0 +1,157 @@
|
||||
import node_helpers
|
||||
import comfy.utils
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
# fmt: off
|
||||
BUCKETS_1024 = [
|
||||
(512, 1792), (512, 1856), (512, 1920), (512, 1984), (512, 2048),
|
||||
(576, 1600), (576, 1664), (576, 1728), (576, 1792),
|
||||
(640, 1472), (640, 1536), (640, 1600),
|
||||
(704, 1344), (704, 1408), (704, 1472),
|
||||
(768, 1216), (768, 1280), (768, 1344),
|
||||
(832, 1152), (832, 1216),
|
||||
(896, 1088), (896, 1152),
|
||||
(960, 1024), (960, 1088),
|
||||
(1024, 960), (1024, 1024),
|
||||
(1088, 896), (1088, 960),
|
||||
(1152, 832), (1152, 896),
|
||||
(1216, 768), (1216, 832),
|
||||
(1280, 768),
|
||||
(1344, 704), (1344, 768),
|
||||
(1408, 704),
|
||||
(1472, 640), (1472, 704),
|
||||
(1536, 640),
|
||||
(1600, 576), (1600, 640),
|
||||
(1664, 576),
|
||||
(1728, 576),
|
||||
(1792, 512), (1792, 576),
|
||||
(1856, 512),
|
||||
(1920, 512),
|
||||
(1984, 512),
|
||||
(2048, 512),
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
def _find_best_bucket(height: int, width: int) -> tuple[int, int]:
|
||||
target_ratio = height / width
|
||||
return min(BUCKETS_1024, key=lambda hw: abs(hw[0] / hw[1] - target_ratio))
|
||||
|
||||
|
||||
class TextEncodeJoyImageEdit(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeJoyImageEdit",
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image"),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
io.Image.Output(display_name="image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, vae, image) -> io.NodeOutput:
|
||||
samples = image.movedim(-1, 1)
|
||||
src_h, src_w = samples.shape[2], samples.shape[3]
|
||||
bucket_h, bucket_w = _find_best_bucket(src_h, src_w)
|
||||
|
||||
resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center")
|
||||
resized_image = resized.movedim(1, -1)[:, :, :, :3]
|
||||
|
||||
tokens = clip.tokenize(prompt, images=[resized_image])
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
|
||||
ref_latent = vae.encode(resized_image)
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
|
||||
|
||||
return io.NodeOutput(conditioning, resized_image)
|
||||
|
||||
|
||||
class TextEncodeJoyImageEditPlus(io.ComfyNode):
|
||||
"""JoyImageEdit multi-image (Plus) text-encode node.
|
||||
|
||||
Accepts 1-6 optional reference images. Each supplied image is
|
||||
bucket-resized independently (same buckets/resize as the single-image
|
||||
node), VAE-encoded, and appended in order to
|
||||
``conditioning["reference_latents"]`` (image1 → ref0, image2 → ref1, ...).
|
||||
All resized images are passed to the VL tower in one call; the tokenizer
|
||||
emits one ``<|vision_start|><|image_pad|><|vision_end|>`` block per image.
|
||||
"""
|
||||
|
||||
MAX_IMAGES = 6
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeJoyImageEditPlus",
|
||||
category="advanced/conditioning",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image1", optional=True),
|
||||
io.Image.Input("image2", optional=True),
|
||||
io.Image.Input("image3", optional=True),
|
||||
io.Image.Input("image4", optional=True),
|
||||
io.Image.Input("image5", optional=True),
|
||||
io.Image.Input("image6", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
io.Image.Output(display_name="image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, vae, image1=None, image2=None, image3=None,
|
||||
image4=None, image5=None, image6=None) -> io.NodeOutput:
|
||||
images = [image1, image2, image3, image4, image5, image6]
|
||||
supplied = [img for img in images if img is not None]
|
||||
if len(supplied) == 0:
|
||||
raise ValueError(
|
||||
"TextEncodeJoyImageEditPlus requires at least one reference image."
|
||||
)
|
||||
|
||||
resized_images = []
|
||||
ref_latents = []
|
||||
for image in supplied:
|
||||
samples = image.movedim(-1, 1)
|
||||
src_h, src_w = samples.shape[2], samples.shape[3]
|
||||
bucket_h, bucket_w = _find_best_bucket(src_h, src_w)
|
||||
|
||||
resized = comfy.utils.common_upscale(samples, bucket_w, bucket_h, "bilinear", "center")
|
||||
resized_image = resized.movedim(1, -1)[:, :, :, :3]
|
||||
resized_images.append(resized_image)
|
||||
ref_latents.append(vae.encode(resized_image))
|
||||
|
||||
tokens = clip.tokenize(prompt, images=resized_images)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
conditioning = node_helpers.conditioning_set_values(
|
||||
conditioning, {"reference_latents": ref_latents}, append=True,
|
||||
)
|
||||
|
||||
# The last reference sets the target resolution; return it for VAEEncode and the
|
||||
# matching negative encode.
|
||||
return io.NodeOutput(conditioning, resized_images[-1])
|
||||
|
||||
|
||||
class JoyImageExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeJoyImageEdit,
|
||||
TextEncodeJoyImageEditPlus,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> JoyImageExtension:
|
||||
return JoyImageExtension()
|
||||
3
nodes.py
3
nodes.py
@ -992,7 +992,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2", "joyimage"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -2460,6 +2460,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_tcfg.py",
|
||||
"nodes_context_windows.py",
|
||||
"nodes_qwen.py",
|
||||
"nodes_joyimage.py",
|
||||
"nodes_boogu.py",
|
||||
"nodes_chroma_radiance.py",
|
||||
"nodes_pid.py",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user