From 1d9124203f313a22ac7715d44dc22950fc8875e9 Mon Sep 17 00:00:00 2001
From: nolan4
Date: Thu, 23 Oct 2025 22:56:32 -0700
Subject: [PATCH] Fix application of entity-specific RoPE embeddings

---
 comfy/ldm/qwen_image/model.py | 134 +++++------------------------
 comfy/model_base.py           |  22 ++----
 comfy_extras/nodes_qwen.py    |  37 ++++-----
 3 files changed, 48 insertions(+), 145 deletions(-)

diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py
index 47fa7a5f6..7ac45c9a9 100644
--- a/comfy/ldm/qwen_image/model.py
+++ b/comfy/ldm/qwen_image/model.py
@@ -339,12 +339,6 @@ class Attention(nn.Module):
         joint_key = joint_key.permute(0, 2, 1, 3)
         joint_value = joint_value.permute(0, 2, 1, 3)
 
-        import os
-        if os.environ.get("ELIGEN_DEBUG"):
-            print(f"[EliGen Debug Attention] Using PyTorch SDPA directly")
-            print(f"  - Query shape: {joint_query.shape}")
-            print(f"  - Mask shape: {effective_mask.shape}")
-            print(f"  - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}")
 
         # Apply SDPA with mask (research-accurate)
         joint_hidden_states = torch.nn.functional.scaled_dot_product_attention(
@@ -592,15 +586,14 @@ class QwenImageTransformer2DModel(nn.Module):
 
         # SECTION 1: Concatenate entity + global prompts
         all_prompt_emb = entity_prompt_emb + [prompt_emb]
-        all_prompt_emb = [self.txt_in(self.txt_norm(p)) for p in all_prompt_emb]
+        all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
        all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
 
-        # SECTION 2: Build RoPE position embeddings (RESEARCH-ACCURATE using QwenEmbedRope)
+        # SECTION 2: Build RoPE position embeddings
         # Calculate img_shapes for RoPE (batch, height//16, width//16 for images in latent space after patchifying)
         img_shapes = [(latents.shape[0], height//16, width//16)]
 
-        # Calculate sequence lengths for entities and global prompt (RESEARCH-ACCURATE)
-        # Research code: seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
+        # Calculate sequence lengths for entities and global prompt
         entity_seq_lens = [int(mask.sum(dim=1).item()) for mask in entity_prompt_emb_mask]
 
         # Handle None case in ComfyUI (None means no padding, all tokens valid)
@@ -611,56 +604,33 @@
         global_seq_len = int(prompt_emb.shape[1])
 
         # Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs))
-        # RESEARCH: image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
         txt_seq_lens = [global_seq_len]
         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
 
-        # Create SEPARATE RoPE embeddings for each entity (EXACTLY like research)
-        # RESEARCH: entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
-        entity_rotary_emb = []
+        # Create SEPARATE RoPE embeddings for each entity
+        # Each entity gets its own positional encoding based on its sequence length
+        entity_rotary_emb = [self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1]
+                             for entity_seq_len in entity_seq_lens]
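+        # Worked example (hypothetical sizes, for clarity): with two entity
+        # prompts of 12 and 9 valid tokens and a 77-token global prompt, this
+        # yields text RoPE tensors of shape [12, D], [9, D] and [77, D], where
+        # D is the rotary feature size; positions restart at 0 for each prompt,
+        # so every entity is positionally encoded as its own short sequence.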
 
-        import os
-        debug = os.environ.get("ELIGEN_DEBUG")
-
-        for i, entity_seq_len in enumerate(entity_seq_lens):
-            # Pass as list for compatibility with research API
-            entity_rope = self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1]
-            entity_rotary_emb.append(entity_rope)
-            if debug:
-                print(f"[EliGen Debug RoPE] Entity {i} RoPE shape: {entity_rope.shape}, seq_len: {entity_seq_len}")
-
-        if debug:
-            print(f"[EliGen Debug RoPE] Global RoPE shape: {image_rotary_emb[1].shape}, seq_len: {global_seq_len}")
-            print(f"[EliGen Debug RoPE] Attempting to concatenate {len(entity_rotary_emb)} entity RoPEs + 1 global RoPE")
-
-        # Concatenate entity RoPEs with global RoPE along sequence dimension (EXACTLY like research)
-        # QwenEmbedRope returns 2D tensors with shape [seq_len, features]
-        # Entity ropes: [entity_seq_len, features]
-        # Global rope: [global_seq_len, features]
-        # Concatenate along dim=0 to get [total_seq_len, features]
-        # RESEARCH: txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
+        # Concatenate entity RoPEs with global RoPE along sequence dimension
+        # Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated
         txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
 
-        # Replace text part of tuple (EXACTLY like research)
-        # RESEARCH: image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
+        # Replace text part of tuple with concatenated entity + global RoPE
         image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
 
-        # Debug output for RoPE embeddings
-        import os
-        if os.environ.get("ELIGEN_DEBUG"):
-            print(f"[EliGen Debug RoPE] Number of entities: {len(entity_seq_lens)}")
-            print(f"[EliGen Debug RoPE] Entity sequence lengths: {entity_seq_lens}")
-            print(f"[EliGen Debug RoPE] Global sequence length: {global_seq_len}")
-            print(f"[EliGen Debug RoPE] img_rotary_emb (tuple[0]) shape: {image_rotary_emb[0].shape}")
-            print(f"[EliGen Debug RoPE] txt_rotary_emb (tuple[1]) shape: {image_rotary_emb[1].shape}")
-            print(f"[EliGen Debug RoPE] Total text seq length: {sum(entity_seq_lens) + global_seq_len}")
         # SECTION 3: Prepare spatial masks
         repeat_dim = latents.shape[1]  # 16
         max_masks = entity_masks.shape[1]  # N entities
 
         entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
 
-        # Pad masks to match padded latent dimensions (same as process_img does)
+        # Pad masks to match padded latent dimensions
         # entity_masks shape: [1, N, 16, H/8, W/8]
         # Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8]
         padded_h = height // 8
@@ -688,13 +658,6 @@
         seq_lens = entity_seq_lens + [global_seq_len]
         total_seq_len = int(sum(seq_lens) + image.shape[1])
 
-        # Debug: Check mask dimensions
-        import os
-        if os.environ.get("ELIGEN_DEBUG"):
-            print(f"[EliGen Debug Patchify] entity_masks[0] shape: {entity_masks[0].shape}")
-            print(f"[EliGen Debug Patchify] height={height}, width={width}, height//16={height//16}, width//16={width//16}")
-            print(f"[EliGen Debug Patchify] Expected mask size: {height//16 * 2} x {width//16 * 2} = {(height//16) * 2} x {(width//16) * 2}")
-
         patched_masks = []
         for i in range(N):
             patched_mask = rearrange(
@@ -753,43 +716,11 @@
             attention_mask[attention_mask == 1] = 0
             attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
 
-        if debug:
-            print(f"\n[EliGen Debug Mask Values]")
-            print(f"  Token ranges:")
-            for i in range(len(seq_lens)):
-                if i < len(seq_lens) - 1:
-                    print(f"  - Entity {i} tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})")
-                else:
-                    print(f"  - Global tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})")
-            print(f"  - Image tokens: {sum(seq_lens)}-{total_seq_len-1}")
-
-            print(f"\n  Checking Entity 0 connections:")
-            # Entity 0 to itself (should be 0)
-            e0_to_e0 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[0]:cumsum[1]]
-            print(f"  - Entity0->Entity0: {(e0_to_e0 == 0).sum()}/{e0_to_e0.numel()} allowed")
-
-            # Entity 0 to Entity 1 (should be -inf)
-            if len(seq_lens) > 2:
-                e0_to_e1 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[1]:cumsum[2]]
-                print(f"  - Entity0->Entity1: {(e0_to_e1 == float('-inf')).sum()}/{e0_to_e1.numel()} blocked")
-
-            # Entity 0 to Global (should be -inf)
-            e0_to_global = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[-2]:cumsum[-1]]
-            print(f"  - Entity0->Global: {(e0_to_global == float('-inf')).sum()}/{e0_to_global.numel()} blocked")
-
-            # Entity 0 to Image (should be partially blocked based on mask)
-            e0_to_img = attention_mask[0, 0, cumsum[0]:cumsum[1], image_start:]
-            print(f"  - Entity0->Image: {(e0_to_img == 0).sum()}/{e0_to_img.numel()} allowed, {(e0_to_img == float('-inf')).sum()} blocked")
-
-            # Image to Entity 0 (should match Entity 0 to Image, transposed)
-            img_to_e0 = attention_mask[0, 0, image_start:, cumsum[0]:cumsum[1]]
-            print(f"  - Image->Entity0: {(img_to_e0 == 0).sum()}/{img_to_e0.numel()} allowed")
-
-            # Global to Image (should be fully allowed)
-            global_to_img = attention_mask[0, 0, cumsum[-2]:cumsum[-1], image_start:]
-            print(f"\n  Checking Global connections:")
-            print(f"  - Global->Image: {(global_to_img == 0).sum()}/{global_to_img.numel()} allowed")
-
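+        # Convention (summarizing the checks in the removed debug block):
+        # after the remapping above, 0 means "may attend" and -inf means
+        # "blocked"; each entity prompt attends to itself and to the image
+        # tokens selected by its spatial mask, while the global prompt and
+        # the image tokens keep full mutual attention.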
print(f" - Image tokens: {sum(seq_lens)}-{total_seq_len-1}") - - print(f"\n Checking Entity 0 connections:") - # Entity 0 to itself (should be 0) - e0_to_e0 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[0]:cumsum[1]] - print(f" - Entity0->Entity0: {(e0_to_e0 == 0).sum()}/{e0_to_e0.numel()} allowed") - - # Entity 0 to Entity 1 (should be -inf) - if len(seq_lens) > 2: - e0_to_e1 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[1]:cumsum[2]] - print(f" - Entity0->Entity1: {(e0_to_e1 == float('-inf')).sum()}/{e0_to_e1.numel()} blocked") - - # Entity 0 to Global (should be -inf) - e0_to_global = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[-2]:cumsum[-1]] - print(f" - Entity0->Global: {(e0_to_global == float('-inf')).sum()}/{e0_to_global.numel()} blocked") - - # Entity 0 to Image (should be partially blocked based on mask) - e0_to_img = attention_mask[0, 0, cumsum[0]:cumsum[1], image_start:] - print(f" - Entity0->Image: {(e0_to_img == 0).sum()}/{e0_to_img.numel()} allowed, {(e0_to_img == float('-inf')).sum()} blocked") - - # Image to Entity 0 (should match Entity 0 to Image, transposed) - img_to_e0 = attention_mask[0, 0, image_start:, cumsum[0]:cumsum[1]] - print(f" - Image->Entity0: {(img_to_e0 == 0).sum()}/{img_to_e0.numel()} allowed") - - # Global to Image (should be fully allowed) - global_to_img = attention_mask[0, 0, cumsum[-2]:cumsum[-1], image_start:] - print(f"\n Checking Global connections:") - print(f" - Global->Image: {(global_to_img == 0).sum()}/{global_to_img.numel()} allowed") - return all_prompt_emb, image_rotary_emb, attention_mask def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): @@ -848,22 +772,17 @@ class QwenImageTransformer2DModel(nn.Module): entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) entity_masks = kwargs.get("entity_masks", None) - # import pdb; pdb.set_trace() - - # Debug logging (set ELIGEN_DEBUG=1 environment variable to enable) - import os - if os.environ.get("ELIGEN_DEBUG"): - if entity_prompt_emb is not None: - print(f"[EliGen Debug] Entity data found!") - print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}") - print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}") - print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}") - # Check if this is positive or negative conditioning - cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] - print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}") - else: - print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}") + if entity_prompt_emb is not None: + print(f"[EliGen Debug] Entity data found!") + print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}") + print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}") + print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}") + # Check if this is positive or negative conditioning + cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] + print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}") + else: + print(f"[EliGen Debug] No entity data in kwargs. 
 
         # Branch: EliGen vs Standard path
         # Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0)
@@ -878,11 +797,6 @@
             height = int(orig_shape[-2] * 8)   # Padded latent height -> pixel height (ensure int)
             width = int(orig_shape[-1] * 8)    # Padded latent width -> pixel width (ensure int)
 
-            if os.environ.get("ELIGEN_DEBUG"):
-                print(f"[EliGen Debug] Original latent shape: {x.shape}")
-                print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}")
-                print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}")
-                print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}")
 
             # Call process_entity_masks to get concatenated text, RoPE, and attention mask
             encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 050e10c98..869fd75bd 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -119,7 +119,6 @@ def convert_tensor(extra, dtype, device):
         extra = comfy.model_management.cast_to_device(extra, device, None)
     return extra
 
-
 class BaseModel(torch.nn.Module):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
         super().__init__()
@@ -381,7 +380,6 @@
     def extra_conds_shapes(self, **kwargs):
         return {}
 
-
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
     adm_inputs = []
     weights = []
@@ -477,7 +475,6 @@
         flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
 
-
 class SVD_img2vid(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -554,7 +551,6 @@
         out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out))
         return torch.cat(out, dim=1)
 
-
 class Stable_Zero123(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
         super().__init__(model_config, model_type, device=device)
@@ -638,13 +634,11 @@
         image = utils.resize_to_batch_size(image, noise.shape[0])
         return self.process_ip2p_image_in(image)
 
-
 class SD15_instructpix2pix(IP2P, BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
         self.process_ip2p_image_in = lambda image: image
 
-
 class SDXL_instructpix2pix(IP2P, SDXL):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -694,7 +688,6 @@
             out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
         return out
 
-
 class StableCascade_B(BaseModel):
     def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=StageB)
@@ -714,7 +707,6 @@
         out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
         return out
 
-
 class SD3(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper)
@@ -729,7 +721,6 @@
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out
 
-
 class AuraFlow(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.aura.mmdit.MMDiT)
@@ -741,7 +732,6 @@
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out
 
-
 class StableAudio1(BaseModel):
     def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
@@ -780,7 +770,6 @@
                 sd["{}{}".format(k, l)] = s[l]
         return sd
 
-
 class HunyuanDiT(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
@@ -914,7 +903,6 @@
             out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
         return out
 
-
 class GenmoMochi(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -1166,7 +1154,6 @@
 
         return out
 
-
 class WAN21_Vace(WAN21):
     def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
@@ -1466,15 +1453,18 @@ class QwenImage(BaseModel):
 
         # Handle EliGen entity data
         entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
         if entity_prompt_emb is not None:
-            out['entity_prompt_emb'] = entity_prompt_emb  # Already wrapped in CONDList by node
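+            # Wrapping happens here rather than in the node, so the conditioning
+            # dict carries plain tensors/lists; CONDList/CONDRegular let the
+            # per-entity conds be concatenated correctly during CFG batching.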
+            out['entity_prompt_emb'] = comfy.conds.CONDList(entity_prompt_emb)
 
         entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
         if entity_prompt_emb_mask is not None:
-            out['entity_prompt_emb_mask'] = entity_prompt_emb_mask  # Already wrapped in CONDList by node
+            out['entity_prompt_emb_mask'] = comfy.conds.CONDList(entity_prompt_emb_mask)
 
         entity_masks = kwargs.get("entity_masks", None)
         if entity_masks is not None:
-            out['entity_masks'] = entity_masks  # Already wrapped in CONDRegular by node
+            out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)
 
         return out
diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py
index 184fdfcff..d8ebbf462 100644
--- a/comfy_extras/nodes_qwen.py
+++ b/comfy_extras/nodes_qwen.py
@@ -176,17 +176,15 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
 
         entity_prompt_emb_list = []
         entity_prompt_emb_mask_list = []
 
-        for entity_prompt, _ in valid_entities:
+        for entity_prompt, _ in valid_entities:  # the spatial mask is consumed separately below
             entity_tokens = clip.tokenize(entity_prompt)
-            entity_cond = clip.encode_from_tokens_scheduled(entity_tokens)
-
-            # Extract embeddings and masks from conditioning
-            # Conditioning format: [[cond_tensor, extra_dict], ...]
-            entity_prompt_emb = entity_cond[0][0]  # The embedding tensor directly [1, seq_len, 3584]
-            extra_dict = entity_cond[0][1]  # Metadata dict (pooled_output, attention_mask, etc.)
-
-            # Extract attention mask from metadata dict
-            entity_prompt_emb_mask = extra_dict.get("attention_mask", None)
+            # encode_from_tokens(..., return_dict=True) returns a dict holding the
+            # embedding under "cond" plus extras such as "attention_mask" when the
+            # text encoder provides one
+            entity_cond_dict = clip.encode_from_tokens(entity_tokens, return_pooled=True, return_dict=True)
+
+            entity_prompt_emb = entity_cond_dict["cond"]
+            entity_prompt_emb_mask = entity_cond_dict.get("attention_mask", None)
 
-            # If no attention mask in extra_dict, create one (all True)
+            # If no attention mask was returned, create one (all True)
@@ -194,12 +192,12 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
                 entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len),
                                                     dtype=torch.bool,
                                                     device=entity_prompt_emb.device)
 
             entity_prompt_emb_list.append(entity_prompt_emb)
             entity_prompt_emb_mask_list.append(entity_prompt_emb_mask)
 
         # Process spatial masks to latent space
-        processed_masks = []
+        processed_entity_masks = []
         for i, (_, mask) in enumerate(valid_entities):
             # mask is expected to be [batch, height, width, channels] or [batch, height, width]
             mask_tensor = mask
@@ -244,8 +242,8 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
                 total_pixels = resized_mask.numel()
                 print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)")
 
-            processed_masks.append(resized_mask)
+            processed_entity_masks.append(resized_mask)
 
-        # Stack masks: [batch, num_entities, 1, latent_height, latent_width]
+        # Stack masks: [batch, num_entities, 1, latent_height, latent_width] (the 1 is the selected mask channel)
         # No padding - handle dynamic number of entities
-        entity_masks_tensor = torch.stack(processed_masks, dim=1)
+        entity_masks_tensor = torch.stack(processed_entity_masks, dim=1)
@@ -263,11 +261,12 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
                                                 dtype=torch.bool, device=global_prompt_emb.device)
 
         # Attach entity data to conditioning using conditioning_set_values
-        # Wrap lists in CONDList so they can be properly concatenated during CFG
+        # Plain lists/tensors here; QwenImage.extra_conds (comfy/model_base.py)
+        # wraps them in CONDList/CONDRegular for CFG batching
         entity_data = {
-            "entity_prompt_emb": comfy.conds.CONDList(entity_prompt_emb_list),
-            "entity_prompt_emb_mask": comfy.conds.CONDList(entity_prompt_emb_mask_list),
-            "entity_masks": comfy.conds.CONDRegular(entity_masks_tensor),
+            "entity_prompt_emb": entity_prompt_emb_list,
+            "entity_prompt_emb_mask": entity_prompt_emb_mask_list,
+            "entity_masks": entity_masks_tensor,
         }
 
         conditioning_with_entities = node_helpers.conditioning_set_values(