mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-20 11:32:58 +08:00
appearing functional without rigorous testing
This commit is contained in:
parent
560b1bdfca
commit
ffe3503370
@ -2,8 +2,9 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from einops import repeat
|
from einops import repeat, rearrange
|
||||||
|
|
||||||
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
|
||||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||||
@ -11,6 +12,118 @@ from comfy.ldm.flux.layers import EmbedND
|
|||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
|
|
||||||
|
|
||||||
|
class QwenEmbedRope(nn.Module):
|
||||||
|
"""Research-accurate RoPE implementation for EliGen.
|
||||||
|
|
||||||
|
This class matches the research pipeline's QwenEmbedRope exactly.
|
||||||
|
Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE.
|
||||||
|
"""
|
||||||
|
def __init__(self, theta: int, axes_dim: list, scale_rope=False):
|
||||||
|
super().__init__()
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
pos_index = torch.arange(4096)
|
||||||
|
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||||
|
self.pos_freqs = torch.cat([
|
||||||
|
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||||
|
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||||
|
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||||
|
], dim=1)
|
||||||
|
self.neg_freqs = torch.cat([
|
||||||
|
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||||
|
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||||
|
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||||
|
], dim=1)
|
||||||
|
self.rope_cache = {}
|
||||||
|
self.scale_rope = scale_rope
|
||||||
|
|
||||||
|
def rope_params(self, index, dim, theta=10000):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
||||||
|
"""
|
||||||
|
assert dim % 2 == 0
|
||||||
|
freqs = torch.outer(
|
||||||
|
index,
|
||||||
|
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
|
||||||
|
)
|
||||||
|
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
|
||||||
|
if isinstance(video_fhw, list):
|
||||||
|
video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
|
||||||
|
_, height, width = video_fhw
|
||||||
|
if self.scale_rope:
|
||||||
|
max_vid_index = max(height // 2, width // 2)
|
||||||
|
else:
|
||||||
|
max_vid_index = max(height, width)
|
||||||
|
required_len = max_vid_index + max(txt_seq_lens)
|
||||||
|
cur_max_len = self.pos_freqs.shape[0]
|
||||||
|
if required_len <= cur_max_len:
|
||||||
|
return
|
||||||
|
|
||||||
|
new_max_len = math.ceil(required_len / 512) * 512
|
||||||
|
pos_index = torch.arange(new_max_len)
|
||||||
|
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
|
||||||
|
self.pos_freqs = torch.cat([
|
||||||
|
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||||
|
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||||
|
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||||
|
], dim=1)
|
||||||
|
self.neg_freqs = torch.cat([
|
||||||
|
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||||
|
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||||
|
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||||
|
], dim=1)
|
||||||
|
return
|
||||||
|
|
||||||
|
def forward(self, video_fhw, txt_seq_lens, device):
|
||||||
|
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
|
||||||
|
if self.pos_freqs.device != device:
|
||||||
|
self.pos_freqs = self.pos_freqs.to(device)
|
||||||
|
self.neg_freqs = self.neg_freqs.to(device)
|
||||||
|
|
||||||
|
vid_freqs = []
|
||||||
|
max_vid_index = 0
|
||||||
|
for idx, fhw in enumerate(video_fhw):
|
||||||
|
frame, height, width = fhw
|
||||||
|
rope_key = f"{idx}_{height}_{width}"
|
||||||
|
|
||||||
|
if rope_key not in self.rope_cache:
|
||||||
|
seq_lens = frame * height * width
|
||||||
|
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||||
|
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||||
|
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||||
|
if self.scale_rope:
|
||||||
|
freqs_height = torch.cat(
|
||||||
|
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
||||||
|
)
|
||||||
|
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||||
|
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||||
|
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||||
|
|
||||||
|
else:
|
||||||
|
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||||
|
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||||
|
|
||||||
|
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||||
|
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
||||||
|
vid_freqs.append(self.rope_cache[rope_key])
|
||||||
|
|
||||||
|
if self.scale_rope:
|
||||||
|
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
||||||
|
else:
|
||||||
|
max_vid_index = max(height, width, max_vid_index)
|
||||||
|
|
||||||
|
max_len = max(txt_seq_lens)
|
||||||
|
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
||||||
|
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||||
|
|
||||||
|
return vid_freqs, txt_freqs
|
||||||
|
|
||||||
|
|
||||||
class GELU(nn.Module):
|
class GELU(nn.Module):
|
||||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -59,6 +172,24 @@ def apply_rotary_emb(x, freqs_cis):
|
|||||||
return t_out.reshape(*x.shape)
|
return t_out.reshape(*x.shape)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Research-accurate RoPE application for QwenEmbedRope.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input tensor with shape [b, h, s, d] (batch, heads, sequence, dim)
|
||||||
|
freqs_cis: Complex frequency tensor with shape [s, features] from QwenEmbedRope
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rotated tensor with same shape as input
|
||||||
|
"""
|
||||||
|
# x shape: [b, h, s, d]
|
||||||
|
# freqs_cis shape: [s, features] where features = d (complex numbers)
|
||||||
|
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||||
|
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||||
|
return x_out.type_as(x)
|
||||||
|
|
||||||
|
|
||||||
class QwenTimestepProjEmbeddings(nn.Module):
|
class QwenTimestepProjEmbeddings(nn.Module):
|
||||||
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
|
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -149,18 +280,89 @@ class Attention(nn.Module):
|
|||||||
txt_query = self.norm_added_q(txt_query)
|
txt_query = self.norm_added_q(txt_query)
|
||||||
txt_key = self.norm_added_k(txt_key)
|
txt_key = self.norm_added_k(txt_key)
|
||||||
|
|
||||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
# Handle both tuple (EliGen) and single tensor (standard) RoPE formats
|
||||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
if isinstance(image_rotary_emb, tuple):
|
||||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
# EliGen path: Apply RoPE BEFORE concatenation (research-accurate)
|
||||||
|
# txt/img query/key are currently [b, s, h, d], need to rearrange to [b, h, s, d]
|
||||||
|
img_rope, txt_rope = image_rotary_emb
|
||||||
|
|
||||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
# Rearrange to [b, h, s, d] for apply_rotary_emb_qwen
|
||||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
txt_query = txt_query.permute(0, 2, 1, 3) # [b, s, h, d] -> [b, h, s, d]
|
||||||
|
txt_key = txt_key.permute(0, 2, 1, 3)
|
||||||
|
img_query = img_query.permute(0, 2, 1, 3)
|
||||||
|
img_key = img_key.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
joint_query = joint_query.flatten(start_dim=2)
|
# Apply RoPE separately to text and image using research function
|
||||||
joint_key = joint_key.flatten(start_dim=2)
|
txt_query = apply_rotary_emb_qwen(txt_query, txt_rope)
|
||||||
joint_value = joint_value.flatten(start_dim=2)
|
txt_key = apply_rotary_emb_qwen(txt_key, txt_rope)
|
||||||
|
img_query = apply_rotary_emb_qwen(img_query, img_rope)
|
||||||
|
img_key = apply_rotary_emb_qwen(img_key, img_rope)
|
||||||
|
|
||||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
# Rearrange back to [b, s, h, d]
|
||||||
|
txt_query = txt_query.permute(0, 2, 1, 3)
|
||||||
|
txt_key = txt_key.permute(0, 2, 1, 3)
|
||||||
|
img_query = img_query.permute(0, 2, 1, 3)
|
||||||
|
img_key = img_key.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
# Now concatenate
|
||||||
|
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||||
|
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||||
|
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||||
|
else:
|
||||||
|
# Standard path: Concatenate first, then apply RoPE
|
||||||
|
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||||
|
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||||
|
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||||
|
|
||||||
|
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
||||||
|
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
||||||
|
|
||||||
|
# Check if we have an EliGen mask - if so, use PyTorch SDPA directly (research-accurate)
|
||||||
|
has_eligen_mask = False
|
||||||
|
effective_mask = attention_mask
|
||||||
|
if transformer_options is not None:
|
||||||
|
eligen_mask = transformer_options.get("eligen_attention_mask", None)
|
||||||
|
if eligen_mask is not None:
|
||||||
|
has_eligen_mask = True
|
||||||
|
effective_mask = eligen_mask
|
||||||
|
|
||||||
|
# Validate shape
|
||||||
|
expected_seq = joint_query.shape[1]
|
||||||
|
if eligen_mask.shape[-1] != expected_seq:
|
||||||
|
raise ValueError(f"EliGen mask shape {eligen_mask.shape} doesn't match sequence length {expected_seq}")
|
||||||
|
|
||||||
|
if has_eligen_mask:
|
||||||
|
# EliGen path: Use PyTorch SDPA directly (matches research implementation exactly)
|
||||||
|
# Don't flatten - keep in [b, s, h, d] format for SDPA
|
||||||
|
# Reshape to [b, h, s, d] for SDPA
|
||||||
|
joint_query = joint_query.permute(0, 2, 1, 3) # [b, s, h, d] -> [b, h, s, d]
|
||||||
|
joint_key = joint_key.permute(0, 2, 1, 3)
|
||||||
|
joint_value = joint_value.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
|
import os
|
||||||
|
if os.environ.get("ELIGEN_DEBUG"):
|
||||||
|
print(f"[EliGen Debug Attention] Using PyTorch SDPA directly")
|
||||||
|
print(f" - Query shape: {joint_query.shape}")
|
||||||
|
print(f" - Mask shape: {effective_mask.shape}")
|
||||||
|
print(f" - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}")
|
||||||
|
|
||||||
|
# Apply SDPA with mask (research-accurate)
|
||||||
|
joint_hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||||
|
joint_query, joint_key, joint_value,
|
||||||
|
attn_mask=effective_mask,
|
||||||
|
dropout_p=0.0,
|
||||||
|
is_causal=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape back: [b, h, s, d] -> [b, s, h*d]
|
||||||
|
joint_hidden_states = joint_hidden_states.permute(0, 2, 1, 3).flatten(start_dim=2)
|
||||||
|
else:
|
||||||
|
# Standard path: Use ComfyUI's optimized attention
|
||||||
|
joint_query = joint_query.flatten(start_dim=2)
|
||||||
|
joint_key = joint_key.flatten(start_dim=2)
|
||||||
|
joint_value = joint_value.flatten(start_dim=2)
|
||||||
|
|
||||||
|
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options)
|
||||||
|
|
||||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||||
@ -310,6 +512,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||||
|
# Add research-accurate RoPE for EliGen (returns tuple of img_freqs, txt_freqs)
|
||||||
|
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
|
||||||
|
|
||||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||||
embedding_dim=self.inner_dim,
|
embedding_dim=self.inner_dim,
|
||||||
@ -359,6 +563,235 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
||||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||||
|
|
||||||
|
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
|
||||||
|
entity_prompt_emb_mask, entity_masks, height, width, image):
|
||||||
|
"""
|
||||||
|
Process entity masks and build spatial attention mask for EliGen.
|
||||||
|
|
||||||
|
This method:
|
||||||
|
1. Concatenates entity + global prompts
|
||||||
|
2. Builds RoPE embeddings for concatenated text using ComfyUI's pe_embedder
|
||||||
|
3. Creates attention mask enforcing spatial restrictions
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latents: [B, 16, H, W]
|
||||||
|
prompt_emb: [1, seq_len, 3584] - Global prompt
|
||||||
|
prompt_emb_mask: [1, seq_len]
|
||||||
|
entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts
|
||||||
|
entity_prompt_emb_mask: List[[1, L_i]]
|
||||||
|
entity_masks: [1, N, 1, H/8, W/8]
|
||||||
|
height: int
|
||||||
|
width: int
|
||||||
|
image: [B, patches, 64] - Patchified latents
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
all_prompt_emb: [1, total_seq, 3584]
|
||||||
|
image_rotary_emb: RoPE embeddings
|
||||||
|
attention_mask: [1, 1, total_seq, total_seq]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# SECTION 1: Concatenate entity + global prompts
|
||||||
|
all_prompt_emb = entity_prompt_emb + [prompt_emb]
|
||||||
|
all_prompt_emb = [self.txt_in(self.txt_norm(p)) for p in all_prompt_emb]
|
||||||
|
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
|
||||||
|
|
||||||
|
# SECTION 2: Build RoPE position embeddings (RESEARCH-ACCURATE using QwenEmbedRope)
|
||||||
|
# Calculate img_shapes for RoPE (batch, height//16, width//16 for images in latent space after patchifying)
|
||||||
|
img_shapes = [(latents.shape[0], height//16, width//16)]
|
||||||
|
|
||||||
|
# Calculate sequence lengths for entities and global prompt (RESEARCH-ACCURATE)
|
||||||
|
# Research code: seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
|
||||||
|
entity_seq_lens = [int(mask.sum(dim=1).item()) for mask in entity_prompt_emb_mask]
|
||||||
|
|
||||||
|
# Handle None case in ComfyUI (None means no padding, all tokens valid)
|
||||||
|
if prompt_emb_mask is not None:
|
||||||
|
global_seq_len = int(prompt_emb_mask.sum(dim=1).item())
|
||||||
|
else:
|
||||||
|
# No mask = no padding, use full sequence length
|
||||||
|
global_seq_len = int(prompt_emb.shape[1])
|
||||||
|
|
||||||
|
# Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs))
|
||||||
|
# RESEARCH: image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||||
|
txt_seq_lens = [global_seq_len]
|
||||||
|
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
|
||||||
|
|
||||||
|
# Create SEPARATE RoPE embeddings for each entity (EXACTLY like research)
|
||||||
|
# RESEARCH: entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
|
||||||
|
entity_rotary_emb = []
|
||||||
|
|
||||||
|
import os
|
||||||
|
debug = os.environ.get("ELIGEN_DEBUG")
|
||||||
|
|
||||||
|
for i, entity_seq_len in enumerate(entity_seq_lens):
|
||||||
|
# Pass as list for compatibility with research API
|
||||||
|
entity_rope = self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1]
|
||||||
|
entity_rotary_emb.append(entity_rope)
|
||||||
|
if debug:
|
||||||
|
print(f"[EliGen Debug RoPE] Entity {i} RoPE shape: {entity_rope.shape}, seq_len: {entity_seq_len}")
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f"[EliGen Debug RoPE] Global RoPE shape: {image_rotary_emb[1].shape}, seq_len: {global_seq_len}")
|
||||||
|
print(f"[EliGen Debug RoPE] Attempting to concatenate {len(entity_rotary_emb)} entity RoPEs + 1 global RoPE")
|
||||||
|
|
||||||
|
# Concatenate entity RoPEs with global RoPE along sequence dimension (EXACTLY like research)
|
||||||
|
# QwenEmbedRope returns 2D tensors with shape [seq_len, features]
|
||||||
|
# Entity ropes: [entity_seq_len, features]
|
||||||
|
# Global rope: [global_seq_len, features]
|
||||||
|
# Concatenate along dim=0 to get [total_seq_len, features]
|
||||||
|
# RESEARCH: txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
|
||||||
|
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
|
||||||
|
|
||||||
|
# Replace text part of tuple (EXACTLY like research)
|
||||||
|
# RESEARCH: image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
|
||||||
|
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
|
||||||
|
|
||||||
|
# Debug output for RoPE embeddings
|
||||||
|
import os
|
||||||
|
if os.environ.get("ELIGEN_DEBUG"):
|
||||||
|
print(f"[EliGen Debug RoPE] Number of entities: {len(entity_seq_lens)}")
|
||||||
|
print(f"[EliGen Debug RoPE] Entity sequence lengths: {entity_seq_lens}")
|
||||||
|
print(f"[EliGen Debug RoPE] Global sequence length: {global_seq_len}")
|
||||||
|
print(f"[EliGen Debug RoPE] img_rotary_emb (tuple[0]) shape: {image_rotary_emb[0].shape}")
|
||||||
|
print(f"[EliGen Debug RoPE] txt_rotary_emb (tuple[1]) shape: {image_rotary_emb[1].shape}")
|
||||||
|
print(f"[EliGen Debug RoPE] Total text seq length: {sum(entity_seq_lens) + global_seq_len}")
|
||||||
|
|
||||||
|
# SECTION 3: Prepare spatial masks
|
||||||
|
repeat_dim = latents.shape[1] # 16
|
||||||
|
max_masks = entity_masks.shape[1] # N entities
|
||||||
|
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||||
|
|
||||||
|
# Pad masks to match padded latent dimensions (same as process_img does)
|
||||||
|
# entity_masks shape: [1, N, 16, H/8, W/8]
|
||||||
|
# Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8]
|
||||||
|
padded_h = height // 8
|
||||||
|
padded_w = width // 8
|
||||||
|
if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
|
||||||
|
# Validate masks aren't larger than expected (would cause negative padding)
|
||||||
|
assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \
|
||||||
|
f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}"
|
||||||
|
|
||||||
|
# Pad each entity mask
|
||||||
|
pad_h = padded_h - entity_masks.shape[3]
|
||||||
|
pad_w = padded_w - entity_masks.shape[4]
|
||||||
|
entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0)
|
||||||
|
|
||||||
|
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||||
|
|
||||||
|
# Add global mask (all True) - must be same size as padded entity masks
|
||||||
|
global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w),
|
||||||
|
device=latents.device, dtype=latents.dtype)
|
||||||
|
entity_masks = entity_masks + [global_mask]
|
||||||
|
|
||||||
|
# SECTION 4: Patchify masks
|
||||||
|
N = len(entity_masks)
|
||||||
|
batch_size = int(entity_masks[0].shape[0])
|
||||||
|
seq_lens = entity_seq_lens + [global_seq_len]
|
||||||
|
total_seq_len = int(sum(seq_lens) + image.shape[1])
|
||||||
|
|
||||||
|
# Debug: Check mask dimensions
|
||||||
|
import os
|
||||||
|
if os.environ.get("ELIGEN_DEBUG"):
|
||||||
|
print(f"[EliGen Debug Patchify] entity_masks[0] shape: {entity_masks[0].shape}")
|
||||||
|
print(f"[EliGen Debug Patchify] height={height}, width={width}, height//16={height//16}, width//16={width//16}")
|
||||||
|
print(f"[EliGen Debug Patchify] Expected mask size: {height//16 * 2} x {width//16 * 2} = {(height//16) * 2} x {(width//16) * 2}")
|
||||||
|
|
||||||
|
patched_masks = []
|
||||||
|
for i in range(N):
|
||||||
|
patched_mask = rearrange(
|
||||||
|
entity_masks[i],
|
||||||
|
"B C (H P) (W Q) -> B (H W) (C P Q)",
|
||||||
|
H=height//16, W=width//16, P=2, Q=2
|
||||||
|
)
|
||||||
|
patched_masks.append(patched_mask)
|
||||||
|
|
||||||
|
# SECTION 5: Build attention mask matrix
|
||||||
|
attention_mask = torch.ones(
|
||||||
|
(batch_size, total_seq_len, total_seq_len),
|
||||||
|
dtype=torch.bool
|
||||||
|
).to(device=entity_masks[0].device)
|
||||||
|
|
||||||
|
# Calculate positions
|
||||||
|
image_start = int(sum(seq_lens))
|
||||||
|
image_end = int(total_seq_len)
|
||||||
|
cumsum = [0]
|
||||||
|
single_image_seq = int(image_end - image_start)
|
||||||
|
|
||||||
|
for length in seq_lens:
|
||||||
|
cumsum.append(cumsum[-1] + length)
|
||||||
|
|
||||||
|
# RULE 1: Spatial restriction (prompt <-> image)
|
||||||
|
for i in range(N):
|
||||||
|
prompt_start = cumsum[i]
|
||||||
|
prompt_end = cumsum[i+1]
|
||||||
|
|
||||||
|
# Create binary mask for which image patches this entity can attend to
|
||||||
|
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||||
|
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
||||||
|
|
||||||
|
# Always repeat mask to match image sequence length (matches DiffSynth line 480)
|
||||||
|
repeat_time = single_image_seq // image_mask.shape[-1]
|
||||||
|
image_mask = image_mask.repeat(1, 1, repeat_time)
|
||||||
|
|
||||||
|
# Bidirectional restriction:
|
||||||
|
# - Entity prompt can only attend to its masked image regions
|
||||||
|
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
|
||||||
|
# - Image patches can only be updated by prompts that own them
|
||||||
|
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
|
||||||
|
|
||||||
|
# RULE 2: Entity isolation
|
||||||
|
for i in range(N):
|
||||||
|
for j in range(N):
|
||||||
|
if i == j:
|
||||||
|
continue
|
||||||
|
start_i, end_i = cumsum[i], cumsum[i+1]
|
||||||
|
start_j, end_j = cumsum[j], cumsum[j+1]
|
||||||
|
attention_mask[:, start_i:end_i, start_j:end_j] = False
|
||||||
|
|
||||||
|
# SECTION 6: Convert to additive bias
|
||||||
|
attention_mask = attention_mask.float()
|
||||||
|
attention_mask[attention_mask == 0] = float('-inf')
|
||||||
|
attention_mask[attention_mask == 1] = 0
|
||||||
|
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
|
||||||
|
|
||||||
|
if debug:
|
||||||
|
print(f"\n[EliGen Debug Mask Values]")
|
||||||
|
print(f" Token ranges:")
|
||||||
|
for i in range(len(seq_lens)):
|
||||||
|
if i < len(seq_lens) - 1:
|
||||||
|
print(f" - Entity {i} tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})")
|
||||||
|
else:
|
||||||
|
print(f" - Global tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})")
|
||||||
|
print(f" - Image tokens: {sum(seq_lens)}-{total_seq_len-1}")
|
||||||
|
|
||||||
|
print(f"\n Checking Entity 0 connections:")
|
||||||
|
# Entity 0 to itself (should be 0)
|
||||||
|
e0_to_e0 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[0]:cumsum[1]]
|
||||||
|
print(f" - Entity0->Entity0: {(e0_to_e0 == 0).sum()}/{e0_to_e0.numel()} allowed")
|
||||||
|
|
||||||
|
# Entity 0 to Entity 1 (should be -inf)
|
||||||
|
if len(seq_lens) > 2:
|
||||||
|
e0_to_e1 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[1]:cumsum[2]]
|
||||||
|
print(f" - Entity0->Entity1: {(e0_to_e1 == float('-inf')).sum()}/{e0_to_e1.numel()} blocked")
|
||||||
|
|
||||||
|
# Entity 0 to Global (should be -inf)
|
||||||
|
e0_to_global = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[-2]:cumsum[-1]]
|
||||||
|
print(f" - Entity0->Global: {(e0_to_global == float('-inf')).sum()}/{e0_to_global.numel()} blocked")
|
||||||
|
|
||||||
|
# Entity 0 to Image (should be partially blocked based on mask)
|
||||||
|
e0_to_img = attention_mask[0, 0, cumsum[0]:cumsum[1], image_start:]
|
||||||
|
print(f" - Entity0->Image: {(e0_to_img == 0).sum()}/{e0_to_img.numel()} allowed, {(e0_to_img == float('-inf')).sum()} blocked")
|
||||||
|
|
||||||
|
# Image to Entity 0 (should match Entity 0 to Image, transposed)
|
||||||
|
img_to_e0 = attention_mask[0, 0, image_start:, cumsum[0]:cumsum[1]]
|
||||||
|
print(f" - Image->Entity0: {(img_to_e0 == 0).sum()}/{img_to_e0.numel()} allowed")
|
||||||
|
|
||||||
|
# Global to Image (should be fully allowed)
|
||||||
|
global_to_img = attention_mask[0, 0, cumsum[-2]:cumsum[-1], image_start:]
|
||||||
|
print(f"\n Checking Global connections:")
|
||||||
|
print(f" - Global->Image: {(global_to_img == 0).sum()}/{global_to_img.numel()} allowed")
|
||||||
|
|
||||||
|
return all_prompt_emb, image_rotary_emb, attention_mask
|
||||||
|
|
||||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
self._forward,
|
self._forward,
|
||||||
@ -410,15 +843,82 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
|
|
||||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
# Extract entity data from kwargs
|
||||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
|
||||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
entity_masks = kwargs.get("entity_masks", None)
|
||||||
del ids, txt_ids, img_ids
|
|
||||||
|
|
||||||
hidden_states = self.img_in(hidden_states)
|
# import pdb; pdb.set_trace()
|
||||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
|
||||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
|
||||||
|
# Debug logging (set ELIGEN_DEBUG=1 environment variable to enable)
|
||||||
|
import os
|
||||||
|
if os.environ.get("ELIGEN_DEBUG"):
|
||||||
|
if entity_prompt_emb is not None:
|
||||||
|
print(f"[EliGen Debug] Entity data found!")
|
||||||
|
print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}")
|
||||||
|
print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}")
|
||||||
|
print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}")
|
||||||
|
# Check if this is positive or negative conditioning
|
||||||
|
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
|
||||||
|
print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}")
|
||||||
|
else:
|
||||||
|
print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}")
|
||||||
|
|
||||||
|
# Branch: EliGen vs Standard path
|
||||||
|
# Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0)
|
||||||
|
# Negative conditioning should use standard path
|
||||||
|
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
|
||||||
|
is_positive_cond = 0 in cond_or_uncond # 0 = conditional/positive, 1 = unconditional/negative
|
||||||
|
|
||||||
|
if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
|
||||||
|
# EliGen path - process entity masks (POSITIVE CONDITIONING ONLY)
|
||||||
|
# Note: Use padded dimensions from orig_shape, not original latent dimensions
|
||||||
|
# orig_shape is from process_img which pads to patch_size
|
||||||
|
height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int)
|
||||||
|
width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int)
|
||||||
|
|
||||||
|
if os.environ.get("ELIGEN_DEBUG"):
|
||||||
|
print(f"[EliGen Debug] Original latent shape: {x.shape}")
|
||||||
|
print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}")
|
||||||
|
print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}")
|
||||||
|
print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}")
|
||||||
|
|
||||||
|
# Call process_entity_masks to get concatenated text, RoPE, and attention mask
|
||||||
|
encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
|
||||||
|
latents=x,
|
||||||
|
prompt_emb=encoder_hidden_states,
|
||||||
|
prompt_emb_mask=encoder_hidden_states_mask,
|
||||||
|
entity_prompt_emb=entity_prompt_emb,
|
||||||
|
entity_prompt_emb_mask=entity_prompt_emb_mask,
|
||||||
|
entity_masks=entity_masks,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
image=hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
# Apply image projection (text already processed in process_entity_masks)
|
||||||
|
hidden_states = self.img_in(hidden_states)
|
||||||
|
|
||||||
|
# Store attention mask in transformer_options for the attention layers
|
||||||
|
if transformer_options is None:
|
||||||
|
transformer_options = {}
|
||||||
|
transformer_options["eligen_attention_mask"] = eligen_attention_mask
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
del img_ids
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Standard path - existing code
|
||||||
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||||
|
del ids, txt_ids, img_ids
|
||||||
|
|
||||||
|
hidden_states = self.img_in(hidden_states)
|
||||||
|
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||||
|
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||||
|
|
||||||
if guidance is not None:
|
if guidance is not None:
|
||||||
guidance = guidance * 1000
|
guidance = guidance * 1000
|
||||||
|
|||||||
@ -1462,6 +1462,20 @@ class QwenImage(BaseModel):
|
|||||||
ref_latents_method = kwargs.get("reference_latents_method", None)
|
ref_latents_method = kwargs.get("reference_latents_method", None)
|
||||||
if ref_latents_method is not None:
|
if ref_latents_method is not None:
|
||||||
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
|
||||||
|
|
||||||
|
# Handle EliGen entity data
|
||||||
|
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
|
||||||
|
if entity_prompt_emb is not None:
|
||||||
|
out['entity_prompt_emb'] = entity_prompt_emb # Already wrapped in CONDList by node
|
||||||
|
|
||||||
|
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
|
||||||
|
if entity_prompt_emb_mask is not None:
|
||||||
|
out['entity_prompt_emb_mask'] = entity_prompt_emb_mask # Already wrapped in CONDList by node
|
||||||
|
|
||||||
|
entity_masks = kwargs.get("entity_masks", None)
|
||||||
|
if entity_masks is not None:
|
||||||
|
out['entity_masks'] = entity_masks # Already wrapped in CONDRegular by node
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def extra_conds_shapes(self, **kwargs):
|
def extra_conds_shapes(self, **kwargs):
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
import comfy.conds
|
||||||
import math
|
import math
|
||||||
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
@ -104,12 +106,186 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
|
|||||||
return io.NodeOutput(conditioning)
|
return io.NodeOutput(conditioning)
|
||||||
|
|
||||||
|
|
||||||
|
class TextEncodeQwenImageEliGen(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="TextEncodeQwenImageEliGen",
|
||||||
|
category="advanced/conditioning",
|
||||||
|
inputs=[
|
||||||
|
io.Clip.Input("clip"),
|
||||||
|
io.Conditioning.Input("global_conditioning"),
|
||||||
|
io.Latent.Input("latent"),
|
||||||
|
io.Image.Input("entity_mask_1", optional=True),
|
||||||
|
io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
|
||||||
|
io.Image.Input("entity_mask_2", optional=True),
|
||||||
|
io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
|
||||||
|
io.Image.Input("entity_mask_3", optional=True),
|
||||||
|
io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, clip, global_conditioning, latent, entity_prompt_1="", entity_mask_1=None,
|
||||||
|
entity_prompt_2="", entity_mask_2=None, entity_prompt_3="", entity_mask_3=None) -> io.NodeOutput:
|
||||||
|
|
||||||
|
# Extract dimensions from latent tensor
|
||||||
|
# latent["samples"] shape: [batch, channels, latent_h, latent_w]
|
||||||
|
latent_samples = latent["samples"]
|
||||||
|
unpadded_latent_height = latent_samples.shape[2] # Unpadded latent space
|
||||||
|
unpadded_latent_width = latent_samples.shape[3] # Unpadded latent space
|
||||||
|
|
||||||
|
# Calculate padded dimensions (same logic as model's pad_to_patch_size with patch_size=2)
|
||||||
|
# The model pads latents to be multiples of patch_size (2 for Qwen)
|
||||||
|
patch_size = 2
|
||||||
|
pad_h = (patch_size - unpadded_latent_height % patch_size) % patch_size
|
||||||
|
pad_w = (patch_size - unpadded_latent_width % patch_size) % patch_size
|
||||||
|
latent_height = unpadded_latent_height + pad_h # Padded latent dimensions
|
||||||
|
latent_width = unpadded_latent_width + pad_w # Padded latent dimensions
|
||||||
|
|
||||||
|
height = latent_height * 8 # Convert to pixel space for logging
|
||||||
|
width = latent_width * 8
|
||||||
|
|
||||||
|
if pad_h > 0 or pad_w > 0:
|
||||||
|
print(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}")
|
||||||
|
print(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)")
|
||||||
|
|
||||||
|
# Collect entity prompts and masks
|
||||||
|
entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3]
|
||||||
|
entity_masks_raw = [entity_mask_1, entity_mask_2, entity_mask_3]
|
||||||
|
|
||||||
|
# Filter out entities with empty prompts or missing masks
|
||||||
|
valid_entities = []
|
||||||
|
for prompt, mask in zip(entity_prompts, entity_masks_raw):
|
||||||
|
if prompt.strip() and mask is not None:
|
||||||
|
valid_entities.append((prompt, mask))
|
||||||
|
|
||||||
|
# Log warning if some entities were skipped
|
||||||
|
total_prompts_provided = len([p for p in entity_prompts if p.strip()])
|
||||||
|
if len(valid_entities) < total_prompts_provided:
|
||||||
|
print(f"[EliGen] Warning: Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks")
|
||||||
|
|
||||||
|
# If no valid entities, return standard conditioning
|
||||||
|
if len(valid_entities) == 0:
|
||||||
|
return io.NodeOutput(global_conditioning)
|
||||||
|
|
||||||
|
# Encode each entity prompt separately
|
||||||
|
entity_prompt_emb_list = []
|
||||||
|
entity_prompt_emb_mask_list = []
|
||||||
|
|
||||||
|
for entity_prompt, _ in valid_entities:
|
||||||
|
entity_tokens = clip.tokenize(entity_prompt)
|
||||||
|
entity_cond = clip.encode_from_tokens_scheduled(entity_tokens)
|
||||||
|
|
||||||
|
# Extract embeddings and masks from conditioning
|
||||||
|
# Conditioning format: [[cond_tensor, extra_dict], ...]
|
||||||
|
entity_prompt_emb = entity_cond[0][0] # The embedding tensor directly [1, seq_len, 3584]
|
||||||
|
extra_dict = entity_cond[0][1] # Metadata dict (pooled_output, attention_mask, etc.)
|
||||||
|
|
||||||
|
# Extract attention mask from metadata dict
|
||||||
|
entity_prompt_emb_mask = extra_dict.get("attention_mask", None)
|
||||||
|
|
||||||
|
# If no attention mask in extra_dict, create one (all True)
|
||||||
|
if entity_prompt_emb_mask is None:
|
||||||
|
seq_len = entity_prompt_emb.shape[1]
|
||||||
|
entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len),
|
||||||
|
dtype=torch.bool, device=entity_prompt_emb.device)
|
||||||
|
|
||||||
|
entity_prompt_emb_list.append(entity_prompt_emb)
|
||||||
|
entity_prompt_emb_mask_list.append(entity_prompt_emb_mask)
|
||||||
|
|
||||||
|
# Process spatial masks to latent space
|
||||||
|
processed_masks = []
|
||||||
|
for i, (_, mask) in enumerate(valid_entities):
|
||||||
|
# mask is expected to be [batch, height, width, channels] or [batch, height, width]
|
||||||
|
mask_tensor = mask
|
||||||
|
|
||||||
|
# Log original mask dimensions
|
||||||
|
original_shape = mask_tensor.shape
|
||||||
|
if len(original_shape) == 3:
|
||||||
|
orig_h, orig_w = original_shape[0], original_shape[1]
|
||||||
|
elif len(original_shape) == 4:
|
||||||
|
orig_h, orig_w = original_shape[1], original_shape[2]
|
||||||
|
else:
|
||||||
|
orig_h, orig_w = original_shape[-2], original_shape[-1]
|
||||||
|
|
||||||
|
print(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
|
||||||
|
|
||||||
|
# Ensure mask is in [batch, channels, height, width] format for upscale
|
||||||
|
if len(mask_tensor.shape) == 3:
|
||||||
|
# [height, width, channels] -> [1, height, width, channels] (add batch dimension)
|
||||||
|
mask_tensor = mask_tensor.unsqueeze(0)
|
||||||
|
elif len(mask_tensor.shape) == 4 and mask_tensor.shape[-1] in [1, 3, 4]:
|
||||||
|
# [batch, height, width, channels] -> [batch, channels, height, width]
|
||||||
|
mask_tensor = mask_tensor.movedim(-1, 1)
|
||||||
|
|
||||||
|
# Take only first channel if multiple channels
|
||||||
|
if mask_tensor.shape[1] > 1:
|
||||||
|
mask_tensor = mask_tensor[:, 0:1, :, :]
|
||||||
|
|
||||||
|
# Resize to latent space dimensions using nearest neighbor
|
||||||
|
resized_mask = comfy.utils.common_upscale(
|
||||||
|
mask_tensor,
|
||||||
|
latent_width,
|
||||||
|
latent_height,
|
||||||
|
upscale_method="nearest-exact",
|
||||||
|
crop="disabled"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Threshold to binary (0 or 1)
|
||||||
|
# Use > 0 instead of > 0.5 to preserve edge pixels from nearest-neighbor downsampling
|
||||||
|
resized_mask = (resized_mask > 0).float()
|
||||||
|
|
||||||
|
# Log how many pixels are active in the mask
|
||||||
|
active_pixels = (resized_mask > 0).sum().item()
|
||||||
|
total_pixels = resized_mask.numel()
|
||||||
|
print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)")
|
||||||
|
|
||||||
|
processed_masks.append(resized_mask)
|
||||||
|
|
||||||
|
# Stack masks: [batch, num_entities, 1, latent_height, latent_width]
|
||||||
|
# No padding - handle dynamic number of entities
|
||||||
|
entity_masks_tensor = torch.stack(processed_masks, dim=1)
|
||||||
|
|
||||||
|
# Extract global prompt embedding and mask from conditioning
|
||||||
|
# Conditioning format: [[cond_tensor, extra_dict]]
|
||||||
|
global_prompt_emb = global_conditioning[0][0] # The embedding tensor directly
|
||||||
|
global_extra_dict = global_conditioning[0][1] # Metadata dict
|
||||||
|
|
||||||
|
global_prompt_emb_mask = global_extra_dict.get("attention_mask", None)
|
||||||
|
|
||||||
|
# If no attention mask, create one (all True)
|
||||||
|
if global_prompt_emb_mask is None:
|
||||||
|
global_prompt_emb_mask = torch.ones((global_prompt_emb.shape[0], global_prompt_emb.shape[1]),
|
||||||
|
dtype=torch.bool, device=global_prompt_emb.device)
|
||||||
|
|
||||||
|
# Attach entity data to conditioning using conditioning_set_values
|
||||||
|
# Wrap lists in CONDList so they can be properly concatenated during CFG
|
||||||
|
entity_data = {
|
||||||
|
"entity_prompt_emb": comfy.conds.CONDList(entity_prompt_emb_list),
|
||||||
|
"entity_prompt_emb_mask": comfy.conds.CONDList(entity_prompt_emb_mask_list),
|
||||||
|
"entity_masks": comfy.conds.CONDRegular(entity_masks_tensor),
|
||||||
|
}
|
||||||
|
|
||||||
|
conditioning_with_entities = node_helpers.conditioning_set_values(
|
||||||
|
global_conditioning,
|
||||||
|
entity_data,
|
||||||
|
append=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return io.NodeOutput(conditioning_with_entities)
|
||||||
|
|
||||||
|
|
||||||
class QwenExtension(ComfyExtension):
|
class QwenExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
TextEncodeQwenImageEdit,
|
TextEncodeQwenImageEdit,
|
||||||
TextEncodeQwenImageEditPlus,
|
TextEncodeQwenImageEditPlus,
|
||||||
|
TextEncodeQwenImageEliGen,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user