fixed application of entity-specific RoPE embeddings

nolan4 2025-10-23 22:56:32 -07:00
parent ffe3503370
commit 1d9124203f
3 changed files with 47 additions and 145 deletions

View File

@@ -339,8 +339,6 @@ class Attention(nn.Module):
joint_key = joint_key.permute(0, 2, 1, 3)
joint_value = joint_value.permute(0, 2, 1, 3)
import os
if os.environ.get("ELIGEN_DEBUG"):
print(f"[EliGen Debug Attention] Using PyTorch SDPA directly")
print(f" - Query shape: {joint_query.shape}")
print(f" - Mask shape: {effective_mask.shape}")
@@ -592,15 +590,14 @@ class QwenImageTransformer2DModel(nn.Module):
# SECTION 1: Concatenate entity + global prompts
all_prompt_emb = entity_prompt_emb + [prompt_emb]
all_prompt_emb = [self.txt_in(self.txt_norm(p)) for p in all_prompt_emb]
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
# SECTION 2: Build RoPE position embeddings (RESEARCH-ACCURATE using QwenEmbedRope)
# SECTION 2: Build RoPE position embeddings
# Calculate img_shapes for RoPE (batch, height//16, width//16 for images in latent space after patchifying)
img_shapes = [(latents.shape[0], height//16, width//16)]
# Calculate sequence lengths for entities and global prompt (RESEARCH-ACCURATE)
# Research code: seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
# Calculate sequence lengths for entities and global prompt
entity_seq_lens = [int(mask.sum(dim=1).item()) for mask in entity_prompt_emb_mask]
# Handle None case in ComfyUI (None means no padding, all tokens valid)
@@ -611,56 +608,27 @@ class QwenImageTransformer2DModel(nn.Module):
global_seq_len = int(prompt_emb.shape[1])
# Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs))
# RESEARCH: image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
txt_seq_lens = [global_seq_len]
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
# Create SEPARATE RoPE embeddings for each entity (EXACTLY like research)
# RESEARCH: entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
entity_rotary_emb = []
# Create SEPARATE RoPE embeddings for each entity
# Each entity gets its own positional encoding based on its sequence length
entity_rotary_emb = [self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1]
for entity_seq_len in entity_seq_lens]
import os
debug = os.environ.get("ELIGEN_DEBUG")
for i, entity_seq_len in enumerate(entity_seq_lens):
# Pass as list for compatibility with research API
entity_rope = self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1]
entity_rotary_emb.append(entity_rope)
if debug:
print(f"[EliGen Debug RoPE] Entity {i} RoPE shape: {entity_rope.shape}, seq_len: {entity_seq_len}")
if debug:
print(f"[EliGen Debug RoPE] Global RoPE shape: {image_rotary_emb[1].shape}, seq_len: {global_seq_len}")
print(f"[EliGen Debug RoPE] Attempting to concatenate {len(entity_rotary_emb)} entity RoPEs + 1 global RoPE")
# Concatenate entity RoPEs with global RoPE along sequence dimension (EXACTLY like research)
# QwenEmbedRope returns 2D tensors with shape [seq_len, features]
# Entity ropes: [entity_seq_len, features]
# Global rope: [global_seq_len, features]
# Concatenate along dim=0 to get [total_seq_len, features]
# RESEARCH: txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
# Concatenate entity RoPEs with global RoPE along sequence dimension
# Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
# Replace text part of tuple (EXACTLY like research)
# RESEARCH: image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
# Replace text part of tuple with concatenated entity + global RoPE
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
# Debug output for RoPE embeddings
import os
if os.environ.get("ELIGEN_DEBUG"):
print(f"[EliGen Debug RoPE] Number of entities: {len(entity_seq_lens)}")
print(f"[EliGen Debug RoPE] Entity sequence lengths: {entity_seq_lens}")
print(f"[EliGen Debug RoPE] Global sequence length: {global_seq_len}")
print(f"[EliGen Debug RoPE] img_rotary_emb (tuple[0]) shape: {image_rotary_emb[0].shape}")
print(f"[EliGen Debug RoPE] txt_rotary_emb (tuple[1]) shape: {image_rotary_emb[1].shape}")
print(f"[EliGen Debug RoPE] Total text seq length: {sum(entity_seq_lens) + global_seq_len}")
# SECTION 3: Prepare spatial masks
repeat_dim = latents.shape[1] # 16
max_masks = entity_masks.shape[1] # N entities
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
# Pad masks to match padded latent dimensions (same as process_img does)
# Pad masks to match padded latent dimensions
# entity_masks shape: [1, N, 16, H/8, W/8]
# Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8]
padded_h = height // 8
@@ -688,13 +656,6 @@ class QwenImageTransformer2DModel(nn.Module):
seq_lens = entity_seq_lens + [global_seq_len]
total_seq_len = int(sum(seq_lens) + image.shape[1])
# Debug: Check mask dimensions
import os
if os.environ.get("ELIGEN_DEBUG"):
print(f"[EliGen Debug Patchify] entity_masks[0] shape: {entity_masks[0].shape}")
print(f"[EliGen Debug Patchify] height={height}, width={width}, height//16={height//16}, width//16={width//16}")
print(f"[EliGen Debug Patchify] Expected mask size: {height//16 * 2} x {width//16 * 2} = {(height//16) * 2} x {(width//16) * 2}")
patched_masks = []
for i in range(N):
patched_mask = rearrange(
@@ -753,43 +714,6 @@ class QwenImageTransformer2DModel(nn.Module):
attention_mask[attention_mask == 1] = 0
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
if debug:
print(f"\n[EliGen Debug Mask Values]")
print(f" Token ranges:")
for i in range(len(seq_lens)):
if i < len(seq_lens) - 1:
print(f" - Entity {i} tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})")
else:
print(f" - Global tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})")
print(f" - Image tokens: {sum(seq_lens)}-{total_seq_len-1}")
print(f"\n Checking Entity 0 connections:")
# Entity 0 to itself (should be 0)
e0_to_e0 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[0]:cumsum[1]]
print(f" - Entity0->Entity0: {(e0_to_e0 == 0).sum()}/{e0_to_e0.numel()} allowed")
# Entity 0 to Entity 1 (should be -inf)
if len(seq_lens) > 2:
e0_to_e1 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[1]:cumsum[2]]
print(f" - Entity0->Entity1: {(e0_to_e1 == float('-inf')).sum()}/{e0_to_e1.numel()} blocked")
# Entity 0 to Global (should be -inf)
e0_to_global = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[-2]:cumsum[-1]]
print(f" - Entity0->Global: {(e0_to_global == float('-inf')).sum()}/{e0_to_global.numel()} blocked")
# Entity 0 to Image (should be partially blocked based on mask)
e0_to_img = attention_mask[0, 0, cumsum[0]:cumsum[1], image_start:]
print(f" - Entity0->Image: {(e0_to_img == 0).sum()}/{e0_to_img.numel()} allowed, {(e0_to_img == float('-inf')).sum()} blocked")
# Image to Entity 0 (should match Entity 0 to Image, transposed)
img_to_e0 = attention_mask[0, 0, image_start:, cumsum[0]:cumsum[1]]
print(f" - Image->Entity0: {(img_to_e0 == 0).sum()}/{img_to_e0.numel()} allowed")
# Global to Image (should be fully allowed)
global_to_img = attention_mask[0, 0, cumsum[-2]:cumsum[-1], image_start:]
print(f"\n Checking Global connections:")
print(f" - Global->Image: {(global_to_img == 0).sum()}/{global_to_img.numel()} allowed")
return all_prompt_emb, image_rotary_emb, attention_mask
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
@@ -848,12 +772,7 @@ class QwenImageTransformer2DModel(nn.Module):
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
entity_masks = kwargs.get("entity_masks", None)
# import pdb; pdb.set_trace()
# Debug logging (set ELIGEN_DEBUG=1 environment variable to enable)
import os
if os.environ.get("ELIGEN_DEBUG"):
if entity_prompt_emb is not None:
print(f"[EliGen Debug] Entity data found!")
print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}")
@@ -878,7 +797,6 @@ class QwenImageTransformer2DModel(nn.Module):
height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int)
width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int)
if os.environ.get("ELIGEN_DEBUG"):
print(f"[EliGen Debug] Original latent shape: {x.shape}")
print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}")
print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}")

View File

@@ -119,7 +119,6 @@ def convert_tensor(extra, dtype, device):
extra = comfy.model_management.cast_to_device(extra, device, None)
return extra
class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__()
@@ -381,7 +380,6 @@ class BaseModel(torch.nn.Module):
def extra_conds_shapes(self, **kwargs):
return {}
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
adm_inputs = []
weights = []
@@ -477,7 +475,6 @@ class SDXL(BaseModel):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
@@ -554,7 +551,6 @@ class SV3D_p(SVD_img2vid):
out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out))
return torch.cat(out, dim=1)
class Stable_Zero123(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
super().__init__(model_config, model_type, device=device)
@@ -638,13 +634,11 @@ class IP2P:
image = utils.resize_to_batch_size(image, noise.shape[0])
return self.process_ip2p_image_in(image)
class SD15_instructpix2pix(IP2P, BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
self.process_ip2p_image_in = lambda image: image
class SDXL_instructpix2pix(IP2P, SDXL):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device)
@@ -694,7 +688,6 @@ class StableCascade_C(BaseModel):
out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
return out
class StableCascade_B(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
super().__init__(model_config, model_type, device=device, unet_model=StageB)
@@ -714,7 +707,6 @@ class StableCascade_B(BaseModel):
out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
return out
class SD3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper)
@@ -729,7 +721,6 @@ class SD3(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class AuraFlow(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.aura.mmdit.MMDiT)
@@ -741,7 +732,6 @@ class AuraFlow(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class StableAudio1(BaseModel):
def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
@@ -780,7 +770,6 @@ class StableAudio1(BaseModel):
sd["{}{}".format(k, l)] = s[l]
return sd
class HunyuanDiT(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
@@ -914,7 +903,6 @@ class Flux(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -1166,7 +1154,6 @@ class WAN21(BaseModel):
return out
class WAN21_Vace(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
@@ -1466,15 +1453,17 @@ class QwenImage(BaseModel):
# Handle EliGen entity data
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
if entity_prompt_emb is not None:
out['entity_prompt_emb'] = entity_prompt_emb # Already wrapped in CONDList by node
out['entity_prompt_emb'] = comfy.conds.CONDList(entity_prompt_emb)
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
if entity_prompt_emb_mask is not None:
out['entity_prompt_emb_mask'] = entity_prompt_emb_mask # Already wrapped in CONDList by node
out['entity_prompt_emb_mask'] = comfy.conds.CONDList(entity_prompt_emb_mask)
entity_masks = kwargs.get("entity_masks", None)
if entity_masks is not None:
out['entity_masks'] = entity_masks # Already wrapped in CONDRegular by node
out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)
# import pdb; pdb.set_trace()
return out
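
For readability, the net effect of the QwenImage.extra_conds hunk above (with the removed pass-through lines stripped out) is that the raw lists and mask tensor coming from the node are wrapped into conditioning objects here; a minimal sketch, assuming out and kwargs as in the surrounding method.

entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
if entity_prompt_emb is not None:
    out['entity_prompt_emb'] = comfy.conds.CONDList(entity_prompt_emb)

entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
if entity_prompt_emb_mask is not None:
    out['entity_prompt_emb_mask'] = comfy.conds.CONDList(entity_prompt_emb_mask)

entity_masks = kwargs.get("entity_masks", None)
if entity_masks is not None:
    out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)
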

View File

@@ -176,17 +176,12 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
entity_prompt_emb_list = []
entity_prompt_emb_mask_list = []
for entity_prompt, _ in valid_entities:
for entity_prompt, _ in valid_entities: # mask not used at this point
entity_tokens = clip.tokenize(entity_prompt)
entity_cond = clip.encode_from_tokens_scheduled(entity_tokens)
entity_cond_dict = clip.encode_from_tokens(entity_tokens, return_pooled=True, return_dict=True)
# Extract embeddings and masks from conditioning
# Conditioning format: [[cond_tensor, extra_dict], ...]
entity_prompt_emb = entity_cond[0][0] # The embedding tensor directly [1, seq_len, 3584]
extra_dict = entity_cond[0][1] # Metadata dict (pooled_output, attention_mask, etc.)
# Extract attention mask from metadata dict
entity_prompt_emb_mask = extra_dict.get("attention_mask", None)
entity_prompt_emb = entity_cond_dict["cond"]
entity_prompt_emb_mask = entity_cond_dict.get("attention_mask", None)
# If no attention mask in extra_dict, create one (all True)
if entity_prompt_emb_mask is None:
@@ -194,11 +189,12 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len),
dtype=torch.bool, device=entity_prompt_emb.device)
entity_prompt_emb_list.append(entity_prompt_emb)
entity_prompt_emb_mask_list.append(entity_prompt_emb_mask)
# Process spatial masks to latent space
processed_masks = []
processed_entity_masks = []
for i, (_, mask) in enumerate(valid_entities):
# mask is expected to be [batch, height, width, channels] or [batch, height, width]
mask_tensor = mask
@@ -244,11 +240,11 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
total_pixels = resized_mask.numel()
print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)")
processed_masks.append(resized_mask)
processed_entity_masks.append(resized_mask)
# Stack masks: [batch, num_entities, 1, latent_height, latent_width]
# Stack masks: [batch, num_entities, 1, latent_height, latent_width] (1 is selected channel)
# No padding - handle dynamic number of entities
entity_masks_tensor = torch.stack(processed_masks, dim=1)
entity_masks_tensor = torch.stack(processed_entity_masks, dim=1)
# Extract global prompt embedding and mask from conditioning
# Conditioning format: [[cond_tensor, extra_dict]]
@@ -263,11 +259,10 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
dtype=torch.bool, device=global_prompt_emb.device)
# Attach entity data to conditioning using conditioning_set_values
# Wrap lists in CONDList so they can be properly concatenated during CFG
entity_data = {
"entity_prompt_emb": comfy.conds.CONDList(entity_prompt_emb_list),
"entity_prompt_emb_mask": comfy.conds.CONDList(entity_prompt_emb_mask_list),
"entity_masks": comfy.conds.CONDRegular(entity_masks_tensor),
"entity_prompt_emb": entity_prompt_emb_list,
"entity_prompt_emb_mask": entity_prompt_emb_mask_list,
"entity_masks": entity_masks_tensor,
}
conditioning_with_entities = node_helpers.conditioning_set_values(