appearing functional without rigorous testing

nolan4 2025-10-22 22:20:43 -07:00
parent 560b1bdfca
commit ffe3503370
3 changed files with 708 additions and 18 deletions

View File

@ -2,8 +2,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
from einops import repeat, rearrange
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from comfy.ldm.modules.attention import optimized_attention_masked
@ -11,6 +12,118 @@ from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
import comfy.patcher_extension
class QwenEmbedRope(nn.Module):
"""Research-accurate RoPE implementation for EliGen.
This class matches the research pipeline's QwenEmbedRope exactly.
Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE.
"""
def __init__(self, theta: int, axes_dim: list, scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat([
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
], dim=1)
self.neg_freqs = torch.cat([
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
], dim=1)
self.rope_cache = {}
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
"""
assert dim % 2 == 0
freqs = torch.outer(
index,
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
)
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
if isinstance(video_fhw, list):
video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
_, height, width = video_fhw
if self.scale_rope:
max_vid_index = max(height // 2, width // 2)
else:
max_vid_index = max(height, width)
required_len = max_vid_index + max(txt_seq_lens)
cur_max_len = self.pos_freqs.shape[0]
if required_len <= cur_max_len:
return
new_max_len = math.ceil(required_len / 512) * 512
pos_index = torch.arange(new_max_len)
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
self.pos_freqs = torch.cat([
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
], dim=1)
self.neg_freqs = torch.cat([
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
], dim=1)
return
def forward(self, video_fhw, txt_seq_lens, device):
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"
if rope_key not in self.rope_cache:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
self.rope_cache[rope_key] = freqs.clone().contiguous()
vid_freqs.append(self.rope_cache[rope_key])
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
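For orientation (not part of the diff), a minimal usage sketch of the class above; the grid size and prompt length are illustrative assumptions. img_freqs covers the flattened image positions and txt_freqs the text positions offset past the image grid, both as complex tensors with sum(axes_dim) // 2 features per position.

# Illustrative sketch: one frame, a 32x32 patch grid, a 77-token prompt (assumed sizes).
rope = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
img_freqs, txt_freqs = rope([(1, 32, 32)], txt_seq_lens=[77], device=torch.device("cpu"))
# (16 + 56 + 56) // 2 = 64 complex features per position, so:
# img_freqs: [1 * 32 * 32, 64] = [1024, 64], txt_freqs: [77, 64]
print(img_freqs.shape, txt_freqs.shape, img_freqs.dtype)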
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
super().__init__()
@ -59,6 +172,24 @@ def apply_rotary_emb(x, freqs_cis):
return t_out.reshape(*x.shape)
def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: torch.Tensor):
"""
Research-accurate RoPE application for QwenEmbedRope.
Args:
x: Input tensor with shape [b, h, s, d] (batch, heads, sequence, dim)
freqs_cis: Complex frequency tensor with shape [s, features] from QwenEmbedRope
Returns:
Rotated tensor with same shape as input
"""
# x shape: [b, h, s, d]
# freqs_cis shape: [s, features] where features = d (complex numbers)
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
return x_out.type_as(x)
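A quick shape check (not part of the diff) of the helper above, with small assumed dimensions: the last axis of x is read as interleaved (real, imag) pairs, multiplied by the complex frequencies broadcast over batch and heads, and flattened back, so the output shape matches the input and a zero rotation angle is the identity.

# Assumed toy sizes; d must be even so it splits into d // 2 complex pairs.
b, h, s, d = 1, 2, 4, 8
x = torch.randn(b, h, s, d)
freqs = torch.polar(torch.ones(s, d // 2), torch.randn(s, d // 2))  # [s, d // 2] complex
out = apply_rotary_emb_qwen(x, freqs)
assert out.shape == x.shape
# Rotation by a zero angle (freqs == 1 + 0j) leaves x unchanged:
ident = torch.polar(torch.ones(s, d // 2), torch.zeros(s, d // 2))
assert torch.allclose(apply_rotary_emb_qwen(x, ident), x, atol=1e-6)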
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
super().__init__()
@ -149,18 +280,89 @@ class Attention(nn.Module):
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)
# Handle both tuple (EliGen) and single tensor (standard) RoPE formats
if isinstance(image_rotary_emb, tuple):
# EliGen path: Apply RoPE BEFORE concatenation (research-accurate)
# txt/img query/key are currently [b, s, h, d], need to rearrange to [b, h, s, d]
img_rope, txt_rope = image_rotary_emb
# Rearrange to [b, h, s, d] for apply_rotary_emb_qwen
txt_query = txt_query.permute(0, 2, 1, 3) # [b, s, h, d] -> [b, h, s, d]
txt_key = txt_key.permute(0, 2, 1, 3)
img_query = img_query.permute(0, 2, 1, 3)
img_key = img_key.permute(0, 2, 1, 3)
# Apply RoPE separately to text and image using research function
txt_query = apply_rotary_emb_qwen(txt_query, txt_rope)
txt_key = apply_rotary_emb_qwen(txt_key, txt_rope)
img_query = apply_rotary_emb_qwen(img_query, img_rope)
img_key = apply_rotary_emb_qwen(img_key, img_rope)
# Rearrange back to [b, s, h, d]
txt_query = txt_query.permute(0, 2, 1, 3)
txt_key = txt_key.permute(0, 2, 1, 3)
img_query = img_query.permute(0, 2, 1, 3)
img_key = img_key.permute(0, 2, 1, 3)
# Now concatenate
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
else:
# Standard path: Concatenate first, then apply RoPE
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
# Check if we have an EliGen mask - if so, use PyTorch SDPA directly (research-accurate)
has_eligen_mask = False
effective_mask = attention_mask
if transformer_options is not None:
eligen_mask = transformer_options.get("eligen_attention_mask", None)
if eligen_mask is not None:
has_eligen_mask = True
effective_mask = eligen_mask
# Validate shape
expected_seq = joint_query.shape[1]
if eligen_mask.shape[-1] != expected_seq:
raise ValueError(f"EliGen mask shape {eligen_mask.shape} doesn't match sequence length {expected_seq}")
if has_eligen_mask:
# EliGen path: Use PyTorch SDPA directly (matches research implementation exactly)
# Don't flatten - keep in [b, s, h, d] format for SDPA
# Reshape to [b, h, s, d] for SDPA
joint_query = joint_query.permute(0, 2, 1, 3) # [b, s, h, d] -> [b, h, s, d]
joint_key = joint_key.permute(0, 2, 1, 3)
joint_value = joint_value.permute(0, 2, 1, 3)
import os
if os.environ.get("ELIGEN_DEBUG"):
print(f"[EliGen Debug Attention] Using PyTorch SDPA directly")
print(f" - Query shape: {joint_query.shape}")
print(f" - Mask shape: {effective_mask.shape}")
print(f" - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}")
# Apply SDPA with mask (research-accurate)
joint_hidden_states = torch.nn.functional.scaled_dot_product_attention(
joint_query, joint_key, joint_value,
attn_mask=effective_mask,
dropout_p=0.0,
is_causal=False
)
# Reshape back: [b, h, s, d] -> [b, s, h*d]
joint_hidden_states = joint_hidden_states.permute(0, 2, 1, 3).flatten(start_dim=2)
else:
# Standard path: Use ComfyUI's optimized attention
joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, effective_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@ -310,6 +512,8 @@ class QwenImageTransformer2DModel(nn.Module):
self.inner_dim = num_attention_heads * attention_head_dim
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
# Add research-accurate RoPE for EliGen (returns tuple of img_freqs, txt_freqs)
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
@ -359,6 +563,235 @@ class QwenImageTransformer2DModel(nn.Module):
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
entity_prompt_emb_mask, entity_masks, height, width, image):
"""
Process entity masks and build spatial attention mask for EliGen.
This method:
1. Concatenates entity + global prompts
2. Builds RoPE embeddings for concatenated text using ComfyUI's pe_embedder
3. Creates attention mask enforcing spatial restrictions
Args:
latents: [B, 16, H, W]
prompt_emb: [1, seq_len, 3584] - Global prompt
prompt_emb_mask: [1, seq_len]
entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts
entity_prompt_emb_mask: List[[1, L_i]]
entity_masks: [1, N, 1, H/8, W/8]
height: int
width: int
image: [B, patches, 64] - Patchified latents
Returns:
all_prompt_emb: [1, total_seq, 3584]
image_rotary_emb: RoPE embeddings
attention_mask: [1, 1, total_seq, total_seq]
"""
# SECTION 1: Concatenate entity + global prompts
all_prompt_emb = entity_prompt_emb + [prompt_emb]
all_prompt_emb = [self.txt_in(self.txt_norm(p)) for p in all_prompt_emb]
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
# SECTION 2: Build RoPE position embeddings (RESEARCH-ACCURATE using QwenEmbedRope)
# Calculate img_shapes for RoPE (batch, height//16, width//16 for images in latent space after patchifying)
img_shapes = [(latents.shape[0], height//16, width//16)]
# Calculate sequence lengths for entities and global prompt (RESEARCH-ACCURATE)
# Research code: seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
entity_seq_lens = [int(mask.sum(dim=1).item()) for mask in entity_prompt_emb_mask]
# Handle None case in ComfyUI (None means no padding, all tokens valid)
if prompt_emb_mask is not None:
global_seq_len = int(prompt_emb_mask.sum(dim=1).item())
else:
# No mask = no padding, use full sequence length
global_seq_len = int(prompt_emb.shape[1])
# Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs))
# RESEARCH: image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
txt_seq_lens = [global_seq_len]
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
# Create SEPARATE RoPE embeddings for each entity (EXACTLY like research)
# RESEARCH: entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
entity_rotary_emb = []
import os
debug = os.environ.get("ELIGEN_DEBUG")
for i, entity_seq_len in enumerate(entity_seq_lens):
# Pass as list for compatibility with research API
entity_rope = self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1]
entity_rotary_emb.append(entity_rope)
if debug:
print(f"[EliGen Debug RoPE] Entity {i} RoPE shape: {entity_rope.shape}, seq_len: {entity_seq_len}")
if debug:
print(f"[EliGen Debug RoPE] Global RoPE shape: {image_rotary_emb[1].shape}, seq_len: {global_seq_len}")
print(f"[EliGen Debug RoPE] Attempting to concatenate {len(entity_rotary_emb)} entity RoPEs + 1 global RoPE")
# Concatenate entity RoPEs with global RoPE along sequence dimension (EXACTLY like research)
# QwenEmbedRope returns 2D tensors with shape [seq_len, features]
# Entity ropes: [entity_seq_len, features]
# Global rope: [global_seq_len, features]
# Concatenate along dim=0 to get [total_seq_len, features]
# RESEARCH: txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
# Replace text part of tuple (EXACTLY like research)
# RESEARCH: image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
# Debug output for RoPE embeddings
import os
if os.environ.get("ELIGEN_DEBUG"):
print(f"[EliGen Debug RoPE] Number of entities: {len(entity_seq_lens)}")
print(f"[EliGen Debug RoPE] Entity sequence lengths: {entity_seq_lens}")
print(f"[EliGen Debug RoPE] Global sequence length: {global_seq_len}")
print(f"[EliGen Debug RoPE] img_rotary_emb (tuple[0]) shape: {image_rotary_emb[0].shape}")
print(f"[EliGen Debug RoPE] txt_rotary_emb (tuple[1]) shape: {image_rotary_emb[1].shape}")
print(f"[EliGen Debug RoPE] Total text seq length: {sum(entity_seq_lens) + global_seq_len}")
# SECTION 3: Prepare spatial masks
repeat_dim = latents.shape[1] # 16
max_masks = entity_masks.shape[1] # N entities
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
# Pad masks to match padded latent dimensions (same as process_img does)
# entity_masks shape: [1, N, 16, H/8, W/8]
# Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8]
padded_h = height // 8
padded_w = width // 8
if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
# Validate masks aren't larger than expected (would cause negative padding)
assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \
f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}"
# Pad each entity mask
pad_h = padded_h - entity_masks.shape[3]
pad_w = padded_w - entity_masks.shape[4]
entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# Add global mask (all True) - must be same size as padded entity masks
global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w),
device=latents.device, dtype=latents.dtype)
entity_masks = entity_masks + [global_mask]
# SECTION 4: Patchify masks
N = len(entity_masks)
batch_size = int(entity_masks[0].shape[0])
seq_lens = entity_seq_lens + [global_seq_len]
total_seq_len = int(sum(seq_lens) + image.shape[1])
# Debug: Check mask dimensions
import os
if os.environ.get("ELIGEN_DEBUG"):
print(f"[EliGen Debug Patchify] entity_masks[0] shape: {entity_masks[0].shape}")
print(f"[EliGen Debug Patchify] height={height}, width={width}, height//16={height//16}, width//16={width//16}")
print(f"[EliGen Debug Patchify] Expected mask size: {height//16 * 2} x {width//16 * 2} = {(height//16) * 2} x {(width//16) * 2}")
patched_masks = []
for i in range(N):
patched_mask = rearrange(
entity_masks[i],
"B C (H P) (W Q) -> B (H W) (C P Q)",
H=height//16, W=width//16, P=2, Q=2
)
patched_masks.append(patched_mask)
# SECTION 5: Build attention mask matrix
attention_mask = torch.ones(
(batch_size, total_seq_len, total_seq_len),
dtype=torch.bool
).to(device=entity_masks[0].device)
# Calculate positions
image_start = int(sum(seq_lens))
image_end = int(total_seq_len)
cumsum = [0]
single_image_seq = int(image_end - image_start)
for length in seq_lens:
cumsum.append(cumsum[-1] + length)
# RULE 1: Spatial restriction (prompt <-> image)
for i in range(N):
prompt_start = cumsum[i]
prompt_end = cumsum[i+1]
# Create binary mask for which image patches this entity can attend to
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
# Always repeat mask to match image sequence length (matches DiffSynth line 480)
repeat_time = single_image_seq // image_mask.shape[-1]
image_mask = image_mask.repeat(1, 1, repeat_time)
# Bidirectional restriction:
# - Entity prompt can only attend to its masked image regions
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
# - Image patches can only be updated by prompts that own them
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
# RULE 2: Entity isolation
for i in range(N):
for j in range(N):
if i == j:
continue
start_i, end_i = cumsum[i], cumsum[i+1]
start_j, end_j = cumsum[j], cumsum[j+1]
attention_mask[:, start_i:end_i, start_j:end_j] = False
# SECTION 6: Convert to additive bias
attention_mask = attention_mask.float()
attention_mask[attention_mask == 0] = float('-inf')
attention_mask[attention_mask == 1] = 0
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
if debug:
print(f"\n[EliGen Debug Mask Values]")
print(f" Token ranges:")
for i in range(len(seq_lens)):
if i < len(seq_lens) - 1:
print(f" - Entity {i} tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})")
else:
print(f" - Global tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})")
print(f" - Image tokens: {sum(seq_lens)}-{total_seq_len-1}")
print(f"\n Checking Entity 0 connections:")
# Entity 0 to itself (should be 0)
e0_to_e0 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[0]:cumsum[1]]
print(f" - Entity0->Entity0: {(e0_to_e0 == 0).sum()}/{e0_to_e0.numel()} allowed")
# Entity 0 to Entity 1 (should be -inf)
if len(seq_lens) > 2:
e0_to_e1 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[1]:cumsum[2]]
print(f" - Entity0->Entity1: {(e0_to_e1 == float('-inf')).sum()}/{e0_to_e1.numel()} blocked")
# Entity 0 to Global (should be -inf)
e0_to_global = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[-2]:cumsum[-1]]
print(f" - Entity0->Global: {(e0_to_global == float('-inf')).sum()}/{e0_to_global.numel()} blocked")
# Entity 0 to Image (should be partially blocked based on mask)
e0_to_img = attention_mask[0, 0, cumsum[0]:cumsum[1], image_start:]
print(f" - Entity0->Image: {(e0_to_img == 0).sum()}/{e0_to_img.numel()} allowed, {(e0_to_img == float('-inf')).sum()} blocked")
# Image to Entity 0 (should match Entity 0 to Image, transposed)
img_to_e0 = attention_mask[0, 0, image_start:, cumsum[0]:cumsum[1]]
print(f" - Image->Entity0: {(img_to_e0 == 0).sum()}/{img_to_e0.numel()} allowed")
# Global to Image (should be fully allowed)
global_to_img = attention_mask[0, 0, cumsum[-2]:cumsum[-1], image_start:]
print(f"\n Checking Global connections:")
print(f" - Global->Image: {(global_to_img == 0).sum()}/{global_to_img.numel()} allowed")
return all_prompt_emb, image_rotary_emb, attention_mask
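To make the mask layout concrete (not part of the diff): the bias above is laid out as [entity_1 tokens | ... | entity_N tokens | global tokens | image patches]. Each entity's tokens attend only to themselves and to the image patches their spatial mask covers, prompt blocks are mutually blocked (RULE 2 also isolates the global block from the entities), and global-to-image and image-to-image attention stay open. A toy reconstruction of those rules with assumed sizes (two 2-token entities, a 3-token global prompt, 4 image patches):

import torch
seq_lens = [2, 2, 3]  # entity0, entity1, global prompt token counts (assumed)
owned = [torch.tensor([1, 1, 0, 0], dtype=torch.bool),  # patches entity0 may see
         torch.tensor([0, 0, 1, 1], dtype=torch.bool),  # patches entity1 may see
         torch.tensor([1, 1, 1, 1], dtype=torch.bool)]  # global prompt sees all patches
n_img = 4
total = sum(seq_lens) + n_img
mask = torch.ones(total, total, dtype=torch.bool)
cum = [0]
for length in seq_lens:
    cum.append(cum[-1] + length)
img_start = sum(seq_lens)
# RULE 1: prompt <-> image spatial restriction (both directions)
for i, length in enumerate(seq_lens):
    mask[cum[i]:cum[i + 1], img_start:] = owned[i].unsqueeze(0).expand(length, n_img)
    mask[img_start:, cum[i]:cum[i + 1]] = owned[i].unsqueeze(1).expand(n_img, length)
# RULE 2: prompt blocks (entities and global) cannot attend to each other
for i in range(len(seq_lens)):
    for j in range(len(seq_lens)):
        if i != j:
            mask[cum[i]:cum[i + 1], cum[j]:cum[j + 1]] = False
# SECTION 6 equivalent: convert the boolean mask to an additive bias
bias = mask.float()
bias[bias == 0] = float("-inf")
bias[bias == 1] = 0.0
print(bias)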
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
@ -410,15 +843,82 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
# Extract entity data from kwargs
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
entity_masks = kwargs.get("entity_masks", None)
# import pdb; pdb.set_trace()
# Debug logging (set ELIGEN_DEBUG=1 environment variable to enable)
import os
if os.environ.get("ELIGEN_DEBUG"):
if entity_prompt_emb is not None:
print(f"[EliGen Debug] Entity data found!")
print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}")
print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}")
print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}")
# Check if this is positive or negative conditioning
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}")
else:
print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}")
# Branch: EliGen vs Standard path
# Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0)
# Negative conditioning should use standard path
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
is_positive_cond = 0 in cond_or_uncond # 0 = conditional/positive, 1 = unconditional/negative
if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
# EliGen path - process entity masks (POSITIVE CONDITIONING ONLY)
# Note: Use padded dimensions from orig_shape, not original latent dimensions
# orig_shape is from process_img which pads to patch_size
height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int)
width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int)
if os.environ.get("ELIGEN_DEBUG"):
print(f"[EliGen Debug] Original latent shape: {x.shape}")
print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}")
print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}")
print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}")
# Call process_entity_masks to get concatenated text, RoPE, and attention mask
encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
latents=x,
prompt_emb=encoder_hidden_states,
prompt_emb_mask=encoder_hidden_states_mask,
entity_prompt_emb=entity_prompt_emb,
entity_prompt_emb_mask=entity_prompt_emb_mask,
entity_masks=entity_masks,
height=height,
width=width,
image=hidden_states
)
# Apply image projection (text already processed in process_entity_masks)
hidden_states = self.img_in(hidden_states)
# Store attention mask in transformer_options for the attention layers
if transformer_options is None:
transformer_options = {}
transformer_options["eligen_attention_mask"] = eligen_attention_mask
# Clean up
del img_ids
else:
# Standard path - existing code
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
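As a small illustration (not part of the diff) of the gating earlier in this forward pass: the EliGen branch runs only when entity data is present and transformer_options["cond_or_uncond"] reports positive conditioning, where 0 marks cond and 1 marks uncond per the comment in the code. A sketch of the same check in isolation, with hypothetical option dicts:

def eligen_gate(transformer_options, has_entity_data):
    cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
    return has_entity_data and (0 in cond_or_uncond)

print(eligen_gate({"cond_or_uncond": [0]}, True))     # True  -> EliGen path
print(eligen_gate({"cond_or_uncond": [1]}, True))     # False -> standard path for the negative prompt
print(eligen_gate({"cond_or_uncond": [1, 0]}, True))  # True  -> a batched call that includes cond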

View File

@ -1462,6 +1462,20 @@ class QwenImage(BaseModel):
ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
# Handle EliGen entity data
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
if entity_prompt_emb is not None:
out['entity_prompt_emb'] = entity_prompt_emb # Already wrapped in CONDList by node
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
if entity_prompt_emb_mask is not None:
out['entity_prompt_emb_mask'] = entity_prompt_emb_mask # Already wrapped in CONDList by node
entity_masks = kwargs.get("entity_masks", None)
if entity_masks is not None:
out['entity_masks'] = entity_masks # Already wrapped in CONDRegular by node
return out
def extra_conds_shapes(self, **kwargs):

View File

@ -1,6 +1,8 @@
import node_helpers
import comfy.utils
import comfy.conds
import math
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
@ -104,12 +106,186 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
return io.NodeOutput(conditioning)
class TextEncodeQwenImageEliGen(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeQwenImageEliGen",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Conditioning.Input("global_conditioning"),
io.Latent.Input("latent"),
io.Image.Input("entity_mask_1", optional=True),
io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("entity_mask_2", optional=True),
io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("entity_mask_3", optional=True),
io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, global_conditioning, latent, entity_prompt_1="", entity_mask_1=None,
entity_prompt_2="", entity_mask_2=None, entity_prompt_3="", entity_mask_3=None) -> io.NodeOutput:
# Extract dimensions from latent tensor
# latent["samples"] shape: [batch, channels, latent_h, latent_w]
latent_samples = latent["samples"]
unpadded_latent_height = latent_samples.shape[2] # Unpadded latent space
unpadded_latent_width = latent_samples.shape[3] # Unpadded latent space
# Calculate padded dimensions (same logic as model's pad_to_patch_size with patch_size=2)
# The model pads latents to be multiples of patch_size (2 for Qwen)
patch_size = 2
pad_h = (patch_size - unpadded_latent_height % patch_size) % patch_size
pad_w = (patch_size - unpadded_latent_width % patch_size) % patch_size
latent_height = unpadded_latent_height + pad_h # Padded latent dimensions
latent_width = unpadded_latent_width + pad_w # Padded latent dimensions
height = latent_height * 8 # Convert to pixel space for logging
width = latent_width * 8
if pad_h > 0 or pad_w > 0:
print(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width}{latent_height}x{latent_width}")
print(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)")
# Collect entity prompts and masks
entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3]
entity_masks_raw = [entity_mask_1, entity_mask_2, entity_mask_3]
# Filter out entities with empty prompts or missing masks
valid_entities = []
for prompt, mask in zip(entity_prompts, entity_masks_raw):
if prompt.strip() and mask is not None:
valid_entities.append((prompt, mask))
# Log warning if some entities were skipped
total_prompts_provided = len([p for p in entity_prompts if p.strip()])
if len(valid_entities) < total_prompts_provided:
print(f"[EliGen] Warning: Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks")
# If no valid entities, return standard conditioning
if len(valid_entities) == 0:
return io.NodeOutput(global_conditioning)
# Encode each entity prompt separately
entity_prompt_emb_list = []
entity_prompt_emb_mask_list = []
for entity_prompt, _ in valid_entities:
entity_tokens = clip.tokenize(entity_prompt)
entity_cond = clip.encode_from_tokens_scheduled(entity_tokens)
# Extract embeddings and masks from conditioning
# Conditioning format: [[cond_tensor, extra_dict], ...]
entity_prompt_emb = entity_cond[0][0] # The embedding tensor directly [1, seq_len, 3584]
extra_dict = entity_cond[0][1] # Metadata dict (pooled_output, attention_mask, etc.)
# Extract attention mask from metadata dict
entity_prompt_emb_mask = extra_dict.get("attention_mask", None)
# If no attention mask in extra_dict, create one (all True)
if entity_prompt_emb_mask is None:
seq_len = entity_prompt_emb.shape[1]
entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len),
dtype=torch.bool, device=entity_prompt_emb.device)
entity_prompt_emb_list.append(entity_prompt_emb)
entity_prompt_emb_mask_list.append(entity_prompt_emb_mask)
# Process spatial masks to latent space
processed_masks = []
for i, (_, mask) in enumerate(valid_entities):
# mask is expected to be [batch, height, width, channels] or [batch, height, width]
mask_tensor = mask
# Log original mask dimensions
original_shape = mask_tensor.shape
if len(original_shape) == 3:
orig_h, orig_w = original_shape[0], original_shape[1]
elif len(original_shape) == 4:
orig_h, orig_w = original_shape[1], original_shape[2]
else:
orig_h, orig_w = original_shape[-2], original_shape[-1]
print(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
# Ensure mask is in [batch, channels, height, width] format for upscale
if len(mask_tensor.shape) == 3:
# [height, width, channels] -> [1, height, width, channels] (add batch dimension)
mask_tensor = mask_tensor.unsqueeze(0)
elif len(mask_tensor.shape) == 4 and mask_tensor.shape[-1] in [1, 3, 4]:
# [batch, height, width, channels] -> [batch, channels, height, width]
mask_tensor = mask_tensor.movedim(-1, 1)
# Take only first channel if multiple channels
if mask_tensor.shape[1] > 1:
mask_tensor = mask_tensor[:, 0:1, :, :]
# Resize to latent space dimensions using nearest neighbor
resized_mask = comfy.utils.common_upscale(
mask_tensor,
latent_width,
latent_height,
upscale_method="nearest-exact",
crop="disabled"
)
# Threshold to binary (0 or 1)
# Use > 0 instead of > 0.5 to preserve edge pixels from nearest-neighbor downsampling
resized_mask = (resized_mask > 0).float()
# Log how many pixels are active in the mask
active_pixels = (resized_mask > 0).sum().item()
total_pixels = resized_mask.numel()
print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)")
processed_masks.append(resized_mask)
# Stack masks: [batch, num_entities, 1, latent_height, latent_width]
# No padding - handle dynamic number of entities
entity_masks_tensor = torch.stack(processed_masks, dim=1)
# Extract global prompt embedding and mask from conditioning
# Conditioning format: [[cond_tensor, extra_dict]]
global_prompt_emb = global_conditioning[0][0] # The embedding tensor directly
global_extra_dict = global_conditioning[0][1] # Metadata dict
global_prompt_emb_mask = global_extra_dict.get("attention_mask", None)
# If no attention mask, create one (all True)
if global_prompt_emb_mask is None:
global_prompt_emb_mask = torch.ones((global_prompt_emb.shape[0], global_prompt_emb.shape[1]),
dtype=torch.bool, device=global_prompt_emb.device)
# Attach entity data to conditioning using conditioning_set_values
# Wrap lists in CONDList so they can be properly concatenated during CFG
entity_data = {
"entity_prompt_emb": comfy.conds.CONDList(entity_prompt_emb_list),
"entity_prompt_emb_mask": comfy.conds.CONDList(entity_prompt_emb_mask_list),
"entity_masks": comfy.conds.CONDRegular(entity_masks_tensor),
}
conditioning_with_entities = node_helpers.conditioning_set_values(
global_conditioning,
entity_data,
append=True
)
return io.NodeOutput(conditioning_with_entities)
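To see the geometry handling in isolation (not part of the diff): the node pads the latent grid up to a multiple of the patch size, mirroring the model's padding, then downsamples each entity mask to that padded grid with a nearest-neighbour resize and a > 0 threshold. A toy sketch with an assumed 768x1208 pixel target, using F.interpolate as a stand-in for comfy.utils.common_upscale:

import torch
import torch.nn.functional as F

patch_size = 2
unpadded_h, unpadded_w = 768 // 8, 1208 // 8                   # 96, 151 (assumed target)
pad_h = (patch_size - unpadded_h % patch_size) % patch_size    # 0
pad_w = (patch_size - unpadded_w % patch_size) % patch_size    # 1
latent_h, latent_w = unpadded_h + pad_h, unpadded_w + pad_w    # 96, 152

# A fake single-channel entity mask in pixel space, [B, C, H, W]:
mask = (torch.rand(1, 1, 768, 1208) > 0.7).float()
resized = F.interpolate(mask, size=(latent_h, latent_w), mode="nearest-exact")
binary = (resized > 0).float()            # keep any latent cell the mask touches
print(binary.shape)                       # torch.Size([1, 1, 96, 152])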
class QwenExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeQwenImageEdit,
TextEncodeQwenImageEditPlus,
TextEncodeQwenImageEliGen,
]