mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
JoyImageEdit is an image-edit diffusion transformer from JD (jd-opensource),
Apache 2.0. This adds native ComfyUI support so it loads and runs like other
edit models (load checkpoint -> TextEncode + ReferenceLatent -> KSampler ->
VAEDecode), with no diffusers dependency.
Architecture:
- Transformer (comfy/ldm/joyimage/model.py): dual-stream (img/txt) DiT with a
Conv3d patch embed (patch_size [1,2,2]), Wan-style learnable modulation,
and 3D RoPE (rope_dim_list [16,56,56]). All attention goes through
comfy.ldm.modules.attention.optimized_attention.
- Text encoder (comfy/text_encoders/{qwen3_vl,joyimage}.py): a reusable
Qwen3-VL multimodal stack (vision tower + LM) in qwen3_vl.py, plus a thin
JoyImage-specific layer (prompt templates, drop_idx, tokenizer, te() factory)
in joyimage.py that depends on it. text_dim 4096.
- VAE: reuses the existing Wan 2.1 latent format (AutoencoderKLWan), no new
latent format.
- Edit conditioning: reuses the reference_latents mechanism. Reference and
noise latents are stacked on a new n-slot dimension and rotated at the model
boundary (model_base.JoyImage), so the transformer stays 5D-in/5D-out.
Guidance-rescale is built into the CFG path.
Model wiring:
- model_base.JoyImage uses ModelType.FLOW with sampling_settings
multiplier=1000 (the time embedding is trained on t in [0,1000]) and
shift=1.5; FLOW's linear time_snr_shift matches the diffusers
FlowMatchEuler sigma schedule.
- model_detection sniffs the transformer state-dict (double_blocks.*,
condition_embedder.*, 5D img_in Conv3d) to route image_model="joyimage".
- supported_models.JoyImage and the CLIPLoader "joyimage" type register it.
User-facing node TextEncodeJoyImageEdit (comfy_extras/nodes_joyimage.py)
bucket-resizes the input image to the nearest 1024-base bucket, encodes the
prompt with the image, and emits both the conditioning and the bucketed image
so the same pixels feed VAEEncode and the negative encode (JoyImage requires
noise and reference latents to share spatial dims).
470 lines
18 KiB
Python
470 lines
18 KiB
Python
# https://github.com/jdopensource/JoyAI-Image-Edit (Apache 2.0)
|
|
import math
|
|
from typing import Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
|
from comfy.ldm.modules.attention import optimized_attention
|
|
|
|
|
|
class FP32LayerNorm(nn.Module):
    """Layer normalization without weight or bias, evaluated in float32.

    The input is upcast to float32, normalized over its trailing
    ``normalized_shape`` dimensions, and cast back to the dtype it arrived in.
    ``dtype`` and ``device`` are accepted for signature compatibility but are
    not used: the module owns no parameters.
    """

    def __init__(self, normalized_shape, eps: float = 1e-6, dtype=None, device=None):
        super().__init__()
        # A bare int means "normalize over a single trailing dimension".
        dims = (normalized_shape,) if isinstance(normalized_shape, int) else normalized_shape
        self.normalized_shape = tuple(dims)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` layer-normalized in float32, in ``x``'s original dtype."""
        upcast = x.float()
        normed = F.layer_norm(upcast, self.normalized_shape, weight=None, bias=None, eps=self.eps)
        return normed.to(x.dtype)
|
|
|
|
|
|
def _apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to a query/key pair.

    ``freqs_cis`` is a ``(cos, sin)`` pair. Each table is viewed so that it
    keeps ``xq``'s sizes on axis 1 and on the last axis and is size 1 on every
    other axis, then broadcast against both ``xq`` and ``xk``. The rotation is
    evaluated in float32 and each result is cast back to its input's type.

    NOTE(review): the pair rotation flattens from axis 3, so the inputs are
    presumably 4-D (batch, tokens, heads, head_dim) -- confirm against callers.
    """
    last_axis = xq.ndim - 1
    broadcast_shape = [size if axis in (1, last_axis) else 1 for axis, size in enumerate(xq.shape)]
    cos_table = freqs_cis[0].view(*broadcast_shape).to(xq.device)
    sin_table = freqs_cis[1].view(*broadcast_shape).to(xq.device)

    def _rotate_pairs(t):
        # Treat the last axis as interleaved (a, b) pairs and map each to (-b, a).
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        first, second = pairs.unbind(-1)
        return torch.stack([-second, first], dim=-1).flatten(3)

    rotated = [
        (t.float() * cos_table + _rotate_pairs(t) * sin_table).type_as(t)
        for t in (xq, xk)
    ]
    return rotated[0], rotated[1]
|
|
|
|
|
|
class JoyImageModulate(nn.Module):
    """Learned modulation table that is added to a conditioning tensor.

    Holds a zero-initialized parameter of shape ``(1, factor, hidden_size)``.
    ``operations`` is accepted for signature compatibility but is not used.
    """

    def __init__(self, hidden_size: int, factor: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.factor = factor
        initial_table = torch.zeros(1, factor, hidden_size, dtype=dtype, device=device)
        self.modulate_table = nn.Parameter(initial_table)

    def forward(self, x: torch.Tensor) -> list:
        """Add the table to ``x`` and split the sum into ``factor`` tensors.

        A non-3-D ``x`` gets a singleton axis inserted at position 1 so it
        broadcasts against the table. The sum is chunked into ``factor`` pieces
        along axis 1, and each piece has that axis squeezed before being
        returned in a list.
        """
        cond = x if x.ndim == 3 else x.unsqueeze(1)
        # Match the table to the incoming tensor rather than the other way round.
        table = self.modulate_table.to(dtype=cond.dtype, device=cond.device)
        combined = table + cond
        return [piece.squeeze(1) for piece in combined.chunk(self.factor, dim=1)]
|
|
|
|
|
|
class JoyImageFeedForward(nn.Module):
    """Feed-forward stack: tanh-GELU projection, zero-rate dropout, output linear.

    The three layers live in ``self.net`` (an ``nn.ModuleList``) in that order,
    so their positions in the list determine the parameter names.
    """

    def __init__(
        self,
        dim: int,
        inner_dim: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        expand = _GeluApproximate(dim, inner_dim, dtype=dtype, device=device, operations=operations)
        contract = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
        # The dropout slot is kept so the output linear stays at index 2.
        self.net = nn.ModuleList([expand, nn.Dropout(0.0), contract])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every layer of ``self.net`` in order."""
        out = x
        for layer in self.net:
            out = layer(out)
        return out
|
|
|
|
|
|
class _GeluApproximate(nn.Module):
    """Linear projection followed by tanh-approximated GELU."""

    def __init__(self, dim_in: int, dim_out: int, dtype=None, device=None, operations=None):
        super().__init__()
        # The layer comes from ``operations`` so the caller chooses the Linear implementation.
        self.proj = operations.Linear(dim_in, dim_out, bias=True, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and apply GELU with ``approximate="tanh"``."""
        projected = self.proj(x)
        return F.gelu(projected, approximate="tanh")
|
|
|
|
|
|
class JoyImageAttention(nn.Module):
|
|
def __init__(
|
|
self,
|
|
dim: int,
|
|
num_attention_heads: int,
|
|
attention_head_dim: int,
|
|
eps: float = 1e-6,
|
|
dtype=None,
|
|
device=None,
|
|
operations=None,
|
|
):
|
|
super().__init__()
|
|
self.num_attention_heads = num_attention_heads
|
|
inner_dim = num_attention_heads * attention_head_dim
|
|
|
|
self.img_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device)
|
|
self.img_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
|
self.img_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
|
self.img_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
|
|
|
|
self.txt_attn_qkv = operations.Linear(dim, inner_dim * 3, bias=True, dtype=dtype, device=device)
|
|
self.txt_attn_q_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
|
self.txt_attn_k_norm = operations.RMSNorm(attention_head_dim, eps=eps, dtype=dtype, device=device)
|
|
self.txt_attn_proj = operations.Linear(inner_dim, dim, bias=True, dtype=dtype, device=device)
|
|
|
|
def forward(
|
|
self,
|
|
img: torch.Tensor,
|
|
txt: torch.Tensor,
|
|
image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]],
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
heads = self.num_attention_heads
|
|
|
|
img_q, img_k, img_v = self.img_attn_qkv(img).chunk(3, dim=-1)
|
|
txt_q, txt_k, txt_v = self.txt_attn_qkv(txt).chunk(3, dim=-1)
|
|
|
|
img_q = img_q.unflatten(-1, (heads, -1))
|
|
img_k = img_k.unflatten(-1, (heads, -1))
|
|
img_v = img_v.unflatten(-1, (heads, -1))
|
|
txt_q = txt_q.unflatten(-1, (heads, -1))
|
|
txt_k = txt_k.unflatten(-1, (heads, -1))
|
|
txt_v = txt_v.unflatten(-1, (heads, -1))
|
|
|
|
img_q = self.img_attn_q_norm(img_q)
|
|
img_k = self.img_attn_k_norm(img_k)
|
|
txt_q = self.txt_attn_q_norm(txt_q)
|
|
txt_k = self.txt_attn_k_norm(txt_k)
|
|
|
|
if image_rotary_emb is not None:
|
|
vis_freqs, txt_freqs = image_rotary_emb
|
|
if vis_freqs is not None:
|
|
img_q, img_k = _apply_rotary_emb(img_q, img_k, vis_freqs)
|
|
if txt_freqs is not None:
|
|
txt_q, txt_k = _apply_rotary_emb(txt_q, txt_k, txt_freqs)
|
|
|
|
joint_q = torch.cat([img_q, txt_q], dim=1)
|
|
joint_k = torch.cat([img_k, txt_k], dim=1)
|
|
joint_v = torch.cat([img_v, txt_v], dim=1)
|
|
|
|
joint_q = joint_q.flatten(2, 3)
|
|
joint_k = joint_k.flatten(2, 3)
|
|
joint_v = joint_v.flatten(2, 3)
|
|
|
|
joint_out = optimized_attention(joint_q, joint_k, joint_v, heads=heads)
|
|
joint_out = joint_out.to(joint_q.dtype)
|
|
|
|
seq_img = img.shape[1]
|
|
img_out = joint_out[:, :seq_img, :]
|
|
txt_out = joint_out[:, seq_img:, :]
|
|
|
|
img_out = self.img_attn_proj(img_out)
|
|
txt_out = self.txt_attn_proj(txt_out)
|
|
return img_out, txt_out
|
|
|
|
|
|
class JoyImageTransformerBlock(nn.Module):
    """One transformer block with separate image and text streams.

    Each stream has its own modulation table (six outputs), two affine-free
    float32 layer norms, and a feed-forward network. The streams share a single
    ``JoyImageAttention`` module. Both the attention and feed-forward branches
    are applied as gated residual updates.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: float = 4.0,
        eps: float = 1e-6,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        mlp_hidden_dim = int(dim * mlp_width_ratio)
        shared = {"dtype": dtype, "device": device}

        # Image stream.
        self.img_mod = JoyImageModulate(dim, factor=6, operations=operations, **shared)
        self.img_norm1 = FP32LayerNorm(dim, eps=eps, **shared)
        self.img_norm2 = FP32LayerNorm(dim, eps=eps, **shared)
        self.img_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **shared)

        # Text stream.
        self.txt_mod = JoyImageModulate(dim, factor=6, operations=operations, **shared)
        self.txt_norm1 = FP32LayerNorm(dim, eps=eps, **shared)
        self.txt_norm2 = FP32LayerNorm(dim, eps=eps, **shared)
        self.txt_mlp = JoyImageFeedForward(dim, inner_dim=mlp_hidden_dim, operations=operations, **shared)

        # Joint attention shared by both streams.
        self.attn = JoyImageAttention(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            eps=eps,
            operations=operations,
            **shared,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], Optional[Tuple[torch.Tensor, torch.Tensor]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update image and text tokens with gated attention and feed-forward.

        Args:
            hidden_states: image-stream tokens.
            encoder_hidden_states: text-stream tokens.
            temb: conditioning tensor fed to both modulation tables.
            image_rotary_emb: forwarded unchanged to the attention module.

        Returns:
            The updated ``(hidden_states, encoder_hidden_states)``.
        """
        # Each table yields (shift, scale, gate) for attention, then for the MLP.
        i_shift1, i_scale1, i_gate1, i_shift2, i_scale2, i_gate2 = self.img_mod(temb)
        t_shift1, t_scale1, t_gate1, t_shift2, t_scale2, t_gate2 = self.txt_mod(temb)

        def _modulate(normed, shift, scale):
            # unsqueeze(1) lets the modulation broadcast along axis 1 of the tokens.
            return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

        # Attention branch.
        img_attn_in = _modulate(self.img_norm1(hidden_states), i_shift1, i_scale1)
        txt_attn_in = _modulate(self.txt_norm1(encoder_hidden_states), t_shift1, t_scale1)
        img_attn, txt_attn = self.attn(img_attn_in, txt_attn_in, image_rotary_emb)
        hidden_states = hidden_states + img_attn * i_gate1.unsqueeze(1)
        encoder_hidden_states = encoder_hidden_states + txt_attn * t_gate1.unsqueeze(1)

        # Feed-forward branch.
        img_mlp_in = _modulate(self.img_norm2(hidden_states), i_shift2, i_scale2)
        txt_mlp_in = _modulate(self.txt_norm2(encoder_hidden_states), t_shift2, t_scale2)
        hidden_states = hidden_states + self.img_mlp(img_mlp_in) * i_gate2.unsqueeze(1)
        encoder_hidden_states = encoder_hidden_states + self.txt_mlp(txt_mlp_in) * t_gate2.unsqueeze(1)

        return hidden_states, encoder_hidden_states
|
|
|
|
|
|
class JoyImageTimeTextImageEmbedding(nn.Module):
    """Builds the timestep embedding, its modulation projection, and text features.

    The timestep passes through ``Timesteps`` and ``TimestepEmbedding``. A
    SiLU followed by a linear layer then produces the ``time_proj_dim``-wide
    modulation input. Text features are projected to ``dim`` by a two-layer MLP.
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.timesteps_proj = Timesteps(
            num_channels=time_freq_dim,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        self.time_embedder = TimestepEmbedding(
            in_channels=time_freq_dim,
            time_embed_dim=dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.act_fn = nn.SiLU()
        self.time_proj = operations.Linear(dim, time_proj_dim, bias=True, dtype=dtype, device=device)
        self.text_embedder = _PixArtAlphaTextProjection(
            text_embed_dim,
            dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor):
        """Return ``(temb, timestep_proj, projected_text)``.

        The timestep features are cast to the text features' dtype before the
        time embedder runs, and ``temb`` is cast to the text features' type
        afterwards.
        """
        time_features = self.timesteps_proj(timestep)
        time_features = time_features.to(dtype=encoder_hidden_states.dtype)
        temb = self.time_embedder(time_features).type_as(encoder_hidden_states)

        timestep_proj = self.time_proj(self.act_fn(temb))
        projected_text = self.text_embedder(encoder_hidden_states)
        return temb, timestep_proj, projected_text
|
|
|
|
|
|
class _PixArtAlphaTextProjection(nn.Module):
    """Two-layer text projection: linear, tanh-approximated GELU, linear."""

    def __init__(self, in_features: int, hidden_size: int, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.linear_1 = operations.Linear(in_features, hidden_size, bias=True, **factory)
        self.act_1 = nn.GELU(approximate="tanh")
        self.linear_2 = operations.Linear(hidden_size, hidden_size, bias=True, **factory)

    def forward(self, caption: torch.Tensor) -> torch.Tensor:
        """Project ``caption`` from ``in_features`` to ``hidden_size``."""
        hidden = self.linear_1(caption)
        hidden = self.act_1(hidden)
        return self.linear_2(hidden)
|
|
|
|
|
|
class JoyImageTransformer3DModel(nn.Module):
    """JoyImage diffusion transformer over 5-D latents.

    A strided ``Conv3d`` patch-embeds the latent, a stack of
    ``JoyImageTransformerBlock`` layers processes image and text tokens
    jointly, and a final norm + linear projects image tokens back to patches,
    which ``unpatchify`` reassembles into a 5-D tensor.
    """
    # 6D->5D rotation and reshape happen in JoyImage.apply_model; this module is 5D-in, 5D-out.

    def __init__(
        self,
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 3072,
        num_attention_heads: int = 24,
        text_dim: int = 4096,
        mlp_width_ratio: float = 4.0,
        num_layers: int = 20,
        rope_dim_list: Tuple[int, int, int] = (16, 56, 56),
        rope_type: str = "rope",
        theta: int = 256,
        image_model=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        """Build the patch embed, condition embedder, blocks, and output head.

        Args:
            patch_size: Conv3d kernel and stride; any 3-item sequence
                (lists are still accepted) is copied into a list.
            in_channels: channels of the input latent.
            out_channels: channels of the output; falls back to
                ``in_channels`` when falsy.
            hidden_size: token width.
            num_attention_heads: number of attention heads.
            text_dim: width of the incoming text features.
            mlp_width_ratio: feed-forward width as a multiple of ``hidden_size``.
            num_layers: number of transformer blocks.
            rope_dim_list: rotary dimensions per position axis; must sum to the
                head dimension.
            rope_type: text tokens receive rotary embeddings only when this is
                ``"mrope"``.
            theta: rotary frequency base.
            image_model: accepted and ignored.
            dtype, device, operations: forwarded to every layer constructor.

        Raises:
            ValueError: if ``hidden_size`` is not divisible by
                ``num_attention_heads`` or ``rope_dim_list`` does not sum to
                the head dimension.
        """
        super().__init__()
        self.dtype = dtype
        self.out_channels = out_channels or in_channels
        # Defaults are immutable tuples; instances always hold their own list copy.
        self.patch_size = list(patch_size)
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.rope_dim_list = list(rope_dim_list)
        self.rope_type = rope_type
        self.theta = theta

        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_attention_heads ({num_attention_heads})"
            )
        attention_head_dim = hidden_size // num_attention_heads
        if sum(self.rope_dim_list) != attention_head_dim:
            raise ValueError(
                f"sum(rope_dim_list) ({sum(self.rope_dim_list)}) must equal head_dim ({attention_head_dim})"
            )

        # Non-overlapping patches: kernel size equals stride.
        self.img_in = operations.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=tuple(self.patch_size),
            stride=tuple(self.patch_size),
            dtype=dtype,
            device=device,
        )

        self.condition_embedder = JoyImageTimeTextImageEmbedding(
            dim=hidden_size,
            time_freq_dim=256,
            # Six modulation vectors of width hidden_size (see forward()).
            time_proj_dim=hidden_size * 6,
            text_embed_dim=text_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.double_blocks = nn.ModuleList([
            JoyImageTransformerBlock(
                dim=hidden_size,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                mlp_width_ratio=mlp_width_ratio,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for _ in range(num_layers)
        ])

        self.norm_out = FP32LayerNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.proj_out = operations.Linear(
            hidden_size,
            self.out_channels * math.prod(self.patch_size),
            bias=True,
            dtype=dtype,
            device=device,
        )

    def _axis_rope_tables(self, positions: torch.Tensor, dim: int, device=None):
        """Return ``(cos, sin)`` rotary tables for one position axis.

        Each table has one row per entry of ``positions`` and
        ``2 * (dim // 2)`` columns, with every frequency repeated twice in
        adjacent columns.
        """
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device)[: (dim // 2)] / dim
        inv_freq = 1.0 / (self.theta ** exponents)
        angles = torch.outer(positions.float(), inv_freq)
        return angles.cos().repeat_interleave(2, dim=1), angles.sin().repeat_interleave(2, dim=1)

    def get_rotary_pos_embed(
        self,
        vis_rope_size,
        txt_rope_size: Optional[int] = None,
        device=None,
    ):
        """Build rotary ``(cos, sin)`` tables for image tokens and, optionally, text.

        Args:
            vis_rope_size: token-grid extent per axis; shorter inputs are
                left-padded with 1 to three axes.
            txt_rope_size: number of text tokens, or ``None`` to skip text
                tables.
            device: device on which the tables are created.

        Returns:
            ``(vis_freqs, txt_freqs)``. ``vis_freqs`` is ``(cos, sin)`` with
            one row per grid position in row-major order. ``txt_freqs`` is
            ``None`` when ``txt_rope_size`` is ``None``; otherwise it is
            ``(cos, sin)`` for text positions that start right after the
            largest image coordinate, and every axis uses that same position.

        Raises:
            ValueError: if the rotary dimensions do not sum to the head
                dimension.
        """
        target_ndim = 3
        vis_rope_size = list(vis_rope_size)
        if len(vis_rope_size) != target_ndim:
            vis_rope_size = [1] * (target_ndim - len(vis_rope_size)) + vis_rope_size

        head_dim = self.hidden_size // self.num_attention_heads
        rope_dim_list = self.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        if sum(rope_dim_list) != head_dim:
            raise ValueError("sum(rope_dim_list) should equal head_dim")

        # Integer coordinates 0..s-1 per axis, expanded to the full grid.
        grid = torch.meshgrid(
            *[torch.arange(s, dtype=torch.float32, device=device) for s in vis_rope_size],
            indexing="ij",
        )

        vis_cos, vis_sin = [], []
        for i, dim in enumerate(rope_dim_list):
            cos, sin = self._axis_rope_tables(grid[i].reshape(-1), dim, device=device)
            vis_cos.append(cos)
            vis_sin.append(sin)
        vis_freqs = (torch.cat(vis_cos, dim=1), torch.cat(vis_sin, dim=1))

        if txt_rope_size is None:
            return vis_freqs, None

        # The largest image coordinate is max(vis_rope_size) - 1, so text starts
        # at max(vis_rope_size). Computing this on the host avoids reading the
        # value back from the device with .item().
        grid_txt = torch.arange(txt_rope_size, device=device) + max(vis_rope_size)
        txt_cos, txt_sin = [], []
        for dim in rope_dim_list:
            cos, sin = self._axis_rope_tables(grid_txt, dim, device=device)
            txt_cos.append(cos)
            txt_sin.append(sin)
        txt_freqs = (torch.cat(txt_cos, dim=1), torch.cat(txt_sin, dim=1))

        return vis_freqs, txt_freqs

    def unpatchify(self, x: torch.Tensor, t: int, h: int, w: int) -> torch.Tensor:
        """Reassemble per-token patches into a ``(B, C, T, H, W)`` tensor.

        Args:
            x: tokens with ``t * h * w`` entries on axis 1; the last axis is
                reshaped as ``(pt, ph, pw, c)``.
            t, h, w: token-grid extent along each axis.

        Returns:
            Tensor of shape ``(B, out_channels, t * pt, h * ph, w * pw)``.

        Raises:
            ValueError: if ``x.shape[1]`` is not ``t * h * w``.
        """
        c = self.out_channels
        pt, ph, pw = self.patch_size
        if t * h * w != x.shape[1]:
            raise ValueError(f"Expected t*h*w ({t * h * w}) to equal x.shape[1] ({x.shape[1]})")
        x = x.reshape(x.shape[0], t, h, w, pt, ph, pw, c)
        # (B, t, h, w, pt, ph, pw, c) -> (B, c, t, pt, h, ph, w, pw)
        x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
        return x.reshape(x.shape[0], c, t * pt, h * ph, w * pw)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the transformer on a 5-D latent.

        Args:
            hidden_states: 5-D latent; its last three axes are divided by
                ``patch_size`` to obtain the token grid.
            timestep: timestep tensor passed to the condition embedder.
            encoder_hidden_states: text features passed to the condition
                embedder.

        Returns:
            The 5-D output produced by ``unpatchify``.
        """
        _, _, ot, oh, ow = hidden_states.shape
        tt = ot // self.patch_size[0]
        th = oh // self.patch_size[1]
        tw = ow // self.patch_size[2]

        # Patch-embed, then flatten the grid into a token axis.
        img = self.img_in(hidden_states).flatten(2).transpose(1, 2)

        _, vec, txt = self.condition_embedder(timestep, encoder_hidden_states)
        if vec.shape[-1] > self.hidden_size:
            # Split the flat projection into six modulation vectors.
            vec = vec.unflatten(1, (6, -1))

        txt_seq_len = txt.shape[1]

        vis_freqs, txt_freqs = self.get_rotary_pos_embed(
            vis_rope_size=[tt, th, tw],
            txt_rope_size=txt_seq_len if self.rope_type == "mrope" else None,
            device=hidden_states.device,
        )

        for block in self.double_blocks:
            img, txt = block(
                hidden_states=img,
                encoder_hidden_states=txt,
                temb=vec,
                image_rotary_emb=(vis_freqs, txt_freqs),
            )

        img = self.proj_out(self.norm_out(img))
        img = self.unpatchify(img, tt, th, tw)
        return img
|