Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-20 11:32:58 +08:00)
fixed application of entity-specific RoPE embeddings
This commit is contained in:
parent ffe3503370
commit 1d9124203f
@@ -339,12 +339,10 @@ class Attention(nn.Module):
         joint_key = joint_key.permute(0, 2, 1, 3)
         joint_value = joint_value.permute(0, 2, 1, 3)

-        import os
-        if os.environ.get("ELIGEN_DEBUG"):
-            print(f"[EliGen Debug Attention] Using PyTorch SDPA directly")
-            print(f" - Query shape: {joint_query.shape}")
-            print(f" - Mask shape: {effective_mask.shape}")
-            print(f" - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}")
+        print(f"[EliGen Debug Attention] Using PyTorch SDPA directly")
+        print(f" - Query shape: {joint_query.shape}")
+        print(f" - Mask shape: {effective_mask.shape}")
+        print(f" - Mask min/max: {effective_mask.min()} / {effective_mask.max():.2f}")

         # Apply SDPA with mask (research-accurate)
         joint_hidden_states = torch.nn.functional.scaled_dot_product_attention(
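For reference, the SDPA call shown above takes an additive mask alongside query/key/value. A minimal standalone sketch of that call pattern, with made-up tensor sizes (all shapes below are assumptions for illustration, not values from the model):

    import torch

    # Illustrative shapes: [batch, heads, seq, head_dim] for Q/K/V and a
    # broadcastable additive mask [batch, 1, seq, seq] (0 = attend, -inf = blocked).
    joint_query = torch.randn(1, 8, 32, 64)
    joint_key = torch.randn(1, 8, 32, 64)
    joint_value = torch.randn(1, 8, 32, 64)
    effective_mask = torch.zeros(1, 1, 32, 32)
    effective_mask[:, :, :4, 4:8] = float("-inf")  # block one token range from another

    joint_hidden_states = torch.nn.functional.scaled_dot_product_attention(
        joint_query, joint_key, joint_value, attn_mask=effective_mask
    )
    print(joint_hidden_states.shape)  # torch.Size([1, 8, 32, 64])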
@@ -592,15 +590,14 @@ class QwenImageTransformer2DModel(nn.Module):

         # SECTION 1: Concatenate entity + global prompts
         all_prompt_emb = entity_prompt_emb + [prompt_emb]
-        all_prompt_emb = [self.txt_in(self.txt_norm(p)) for p in all_prompt_emb]
+        all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
         all_prompt_emb = torch.cat(all_prompt_emb, dim=1)

-        # SECTION 2: Build RoPE position embeddings (RESEARCH-ACCURATE using QwenEmbedRope)
+        # SECTION 2: Build RoPE position embeddings
         # Calculate img_shapes for RoPE (batch, height//16, width//16 for images in latent space after patchifying)
         img_shapes = [(latents.shape[0], height//16, width//16)]

-        # Calculate sequence lengths for entities and global prompt (RESEARCH-ACCURATE)
-        # Research code: seq_lens = [mask_.sum(dim=1).item() for mask_ in entity_prompt_emb_mask] + [prompt_emb_mask.sum(dim=1).item()]
+        # Calculate sequence lengths for entities and global prompt
         entity_seq_lens = [int(mask.sum(dim=1).item()) for mask in entity_prompt_emb_mask]

         # Handle None case in ComfyUI (None means no padding, all tokens valid)
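A rough sketch of what this concatenation step computes, using stand-in modules for self.txt_norm and self.txt_in (LayerNorm and Linear here are assumptions purely for shape checking; the 3584 width is taken from a comment later in this diff):

    import torch

    hidden = 3584
    txt_norm = torch.nn.LayerNorm(hidden)      # stand-in for self.txt_norm
    txt_in = torch.nn.Linear(hidden, hidden)   # stand-in for self.txt_in

    # Two entity prompts plus one global prompt, each [batch, seq_len, hidden].
    entity_prompt_emb = [torch.randn(1, 12, hidden), torch.randn(1, 9, hidden)]
    prompt_emb = torch.randn(1, 20, hidden)

    all_prompt_emb = entity_prompt_emb + [prompt_emb]
    all_prompt_emb = [txt_in(txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
    all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
    print(all_prompt_emb.shape)  # torch.Size([1, 41, 3584])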
@@ -611,56 +608,27 @@ class QwenImageTransformer2DModel(nn.Module):
         global_seq_len = int(prompt_emb.shape[1])

         # Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs))
-        # RESEARCH: image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
         txt_seq_lens = [global_seq_len]
         image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)

-        # Create SEPARATE RoPE embeddings for each entity (EXACTLY like research)
-        # RESEARCH: entity_rotary_emb = [self.pos_embed(img_shapes, entity_seq_len, device=latents.device)[1] for entity_seq_len in entity_seq_lens]
-        entity_rotary_emb = []
+        # Create SEPARATE RoPE embeddings for each entity
+        # Each entity gets its own positional encoding based on its sequence length
+        entity_rotary_emb = [self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1]
+                             for entity_seq_len in entity_seq_lens]

-        import os
-        debug = os.environ.get("ELIGEN_DEBUG")
-
-        for i, entity_seq_len in enumerate(entity_seq_lens):
-            # Pass as list for compatibility with research API
-            entity_rope = self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1]
-            entity_rotary_emb.append(entity_rope)
-            if debug:
-                print(f"[EliGen Debug RoPE] Entity {i} RoPE shape: {entity_rope.shape}, seq_len: {entity_seq_len}")
-
-        if debug:
-            print(f"[EliGen Debug RoPE] Global RoPE shape: {image_rotary_emb[1].shape}, seq_len: {global_seq_len}")
-            print(f"[EliGen Debug RoPE] Attempting to concatenate {len(entity_rotary_emb)} entity RoPEs + 1 global RoPE")
-
-        # Concatenate entity RoPEs with global RoPE along sequence dimension (EXACTLY like research)
-        # QwenEmbedRope returns 2D tensors with shape [seq_len, features]
-        # Entity ropes: [entity_seq_len, features]
-        # Global rope: [global_seq_len, features]
-        # Concatenate along dim=0 to get [total_seq_len, features]
-        # RESEARCH: txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
+        # Concatenate entity RoPEs with global RoPE along sequence dimension
+        # Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated
         txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)

-        # Replace text part of tuple (EXACTLY like research)
-        # RESEARCH: image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
+        # Replace text part of tuple with concatenated entity + global RoPE
         image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)

-        # Debug output for RoPE embeddings
-        import os
-        if os.environ.get("ELIGEN_DEBUG"):
-            print(f"[EliGen Debug RoPE] Number of entities: {len(entity_seq_lens)}")
-            print(f"[EliGen Debug RoPE] Entity sequence lengths: {entity_seq_lens}")
-            print(f"[EliGen Debug RoPE] Global sequence length: {global_seq_len}")
-            print(f"[EliGen Debug RoPE] img_rotary_emb (tuple[0]) shape: {image_rotary_emb[0].shape}")
-            print(f"[EliGen Debug RoPE] txt_rotary_emb (tuple[1]) shape: {image_rotary_emb[1].shape}")
-            print(f"[EliGen Debug RoPE] Total text seq length: {sum(entity_seq_lens) + global_seq_len}")
-
         # SECTION 3: Prepare spatial masks
         repeat_dim = latents.shape[1] # 16
         max_masks = entity_masks.shape[1] # N entities
         entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)

-        # Pad masks to match padded latent dimensions (same as process_img does)
+        # Pad masks to match padded latent dimensions
         # entity_masks shape: [1, N, 16, H/8, W/8]
         # Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8]
         padded_h = height // 8
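The shape bookkeeping behind the list comprehension above, sketched with a dummy pos_embed standing in for QwenEmbedRope (which returns an (image_freqs, text_freqs) tuple). The feature width and sequence lengths below are assumptions for illustration:

    import torch

    def dummy_pos_embed(img_shapes, txt_seq_lens, device="cpu"):
        # Stand-in for self.pos_embed: returns (img_freqs, txt_freqs), both [seq, features].
        img_len = sum(b * h * w for b, h, w in img_shapes)
        feat = 64  # assumed rotary feature width
        return torch.randn(img_len, feat), torch.randn(txt_seq_lens[0], feat)

    img_shapes = [(1, 32, 32)]
    entity_seq_lens = [12, 9]
    global_seq_len = 20

    image_rotary_emb = dummy_pos_embed(img_shapes, [global_seq_len])
    entity_rotary_emb = [dummy_pos_embed(img_shapes, [seq_len])[1] for seq_len in entity_seq_lens]

    # Entities first, global prompt last, concatenated along the sequence dimension.
    txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
    image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
    print(image_rotary_emb[1].shape)  # torch.Size([41, 64]) = 12 + 9 + 20 text positions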
@@ -688,13 +656,6 @@ class QwenImageTransformer2DModel(nn.Module):
         seq_lens = entity_seq_lens + [global_seq_len]
         total_seq_len = int(sum(seq_lens) + image.shape[1])

-        # Debug: Check mask dimensions
-        import os
-        if os.environ.get("ELIGEN_DEBUG"):
-            print(f"[EliGen Debug Patchify] entity_masks[0] shape: {entity_masks[0].shape}")
-            print(f"[EliGen Debug Patchify] height={height}, width={width}, height//16={height//16}, width//16={width//16}")
-            print(f"[EliGen Debug Patchify] Expected mask size: {height//16 * 2} x {width//16 * 2} = {(height//16) * 2} x {(width//16) * 2}")
-
         patched_masks = []
         for i in range(N):
             patched_mask = rearrange(
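The rearrange call cut off at the end of this hunk folds latent-space masks into 2x2 patch tokens. A standalone sketch under assumed sizes (the pattern string below is illustrative; the file's exact pattern is not visible in this hunk):

    import torch
    from einops import rearrange

    # One entity mask broadcast over 16 latent channels on a 64x64 latent grid (assumed).
    entity_mask = torch.ones(1, 16, 64, 64)

    # Fold 2x2 spatial patches into tokens: [B, C, H, W] -> [B, (H/2)*(W/2), C*2*2]
    patched_mask = rearrange(entity_mask, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=2, p2=2)
    print(patched_mask.shape)  # torch.Size([1, 1024, 64])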
@@ -753,43 +714,6 @@ class QwenImageTransformer2DModel(nn.Module):
         attention_mask[attention_mask == 1] = 0
         attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)

-        if debug:
-            print(f"\n[EliGen Debug Mask Values]")
-            print(f" Token ranges:")
-            for i in range(len(seq_lens)):
-                if i < len(seq_lens) - 1:
-                    print(f" - Entity {i} tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})")
-                else:
-                    print(f" - Global tokens: {cumsum[i]}-{cumsum[i+1]-1} (length: {seq_lens[i]})")
-            print(f" - Image tokens: {sum(seq_lens)}-{total_seq_len-1}")
-
-            print(f"\n Checking Entity 0 connections:")
-            # Entity 0 to itself (should be 0)
-            e0_to_e0 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[0]:cumsum[1]]
-            print(f" - Entity0->Entity0: {(e0_to_e0 == 0).sum()}/{e0_to_e0.numel()} allowed")
-
-            # Entity 0 to Entity 1 (should be -inf)
-            if len(seq_lens) > 2:
-                e0_to_e1 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[1]:cumsum[2]]
-                print(f" - Entity0->Entity1: {(e0_to_e1 == float('-inf')).sum()}/{e0_to_e1.numel()} blocked")
-
-            # Entity 0 to Global (should be -inf)
-            e0_to_global = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[-2]:cumsum[-1]]
-            print(f" - Entity0->Global: {(e0_to_global == float('-inf')).sum()}/{e0_to_global.numel()} blocked")
-
-            # Entity 0 to Image (should be partially blocked based on mask)
-            e0_to_img = attention_mask[0, 0, cumsum[0]:cumsum[1], image_start:]
-            print(f" - Entity0->Image: {(e0_to_img == 0).sum()}/{e0_to_img.numel()} allowed, {(e0_to_img == float('-inf')).sum()} blocked")
-
-            # Image to Entity 0 (should match Entity 0 to Image, transposed)
-            img_to_e0 = attention_mask[0, 0, image_start:, cumsum[0]:cumsum[1]]
-            print(f" - Image->Entity0: {(img_to_e0 == 0).sum()}/{img_to_e0.numel()} allowed")
-
-            # Global to Image (should be fully allowed)
-            global_to_img = attention_mask[0, 0, cumsum[-2]:cumsum[-1], image_start:]
-            print(f"\n Checking Global connections:")
-            print(f" - Global->Image: {(global_to_img == 0).sum()}/{global_to_img.numel()} allowed")
-
         return all_prompt_emb, image_rotary_emb, attention_mask

     def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
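The checks deleted above were probing an additive mask in which 0 allows attention and -inf blocks it, with text tokens laid out as [entity 0, entity 1, ..., global] followed by image tokens. A small sketch of indexing such a mask by cumulative sequence lengths (all sizes are assumptions):

    import torch

    seq_lens = [12, 9, 20]        # entity 0, entity 1, global prompt
    image_len = 1024
    total_seq_len = sum(seq_lens) + image_len

    cumsum = [0]
    for length in seq_lens:
        cumsum.append(cumsum[-1] + length)
    image_start = sum(seq_lens)

    attention_mask = torch.zeros(1, 1, total_seq_len, total_seq_len)
    # Example rule from the deleted checks: entity 0 must not attend to entity 1.
    attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[1]:cumsum[2]] = float("-inf")

    e0_to_e1 = attention_mask[0, 0, cumsum[0]:cumsum[1], cumsum[1]:cumsum[2]]
    blocked = int((e0_to_e1 == float("-inf")).sum())
    print(f"Entity0->Entity1: {blocked}/{e0_to_e1.numel()} blocked")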
@@ -848,22 +772,17 @@ class QwenImageTransformer2DModel(nn.Module):
         entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
         entity_masks = kwargs.get("entity_masks", None)

         # import pdb; pdb.set_trace()
-
-
-        # Debug logging (set ELIGEN_DEBUG=1 environment variable to enable)
-        import os
-        if os.environ.get("ELIGEN_DEBUG"):
-            if entity_prompt_emb is not None:
-                print(f"[EliGen Debug] Entity data found!")
-                print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}")
-                print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}")
-                print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}")
-                # Check if this is positive or negative conditioning
-                cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
-                print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}")
-            else:
-                print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}")
+        if entity_prompt_emb is not None:
+            print(f"[EliGen Debug] Entity data found!")
+            print(f" - entity_prompt_emb type: {type(entity_prompt_emb)}, len: {len(entity_prompt_emb) if isinstance(entity_prompt_emb, list) else 'N/A'}")
+            print(f" - entity_masks shape: {entity_masks.shape if entity_masks is not None else 'None'}")
+            print(f" - Number of entities: {entity_masks.shape[1] if entity_masks is not None else 'Unknown'}")
+            # Check if this is positive or negative conditioning
+            cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
+            print(f" - Conditioning type: {['uncond' if c == 1 else 'cond' for c in cond_or_uncond]}")
+        else:
+            print(f"[EliGen Debug] No entity data in kwargs. Keys: {list(kwargs.keys())}")

         # Branch: EliGen vs Standard path
         # Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0)
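The cond_or_uncond bookkeeping printed above is what the branch below keys on: 0 marks a positive (cond) batch entry and 1 a negative (uncond) one. A toy sketch with dummy values (the real values arrive via transformer_options):

    transformer_options = {"cond_or_uncond": [0]}   # dummy stand-in
    entity_prompt_emb = object()                    # stand-in for "entity data present"

    cond_or_uncond = transformer_options.get("cond_or_uncond", [])
    print(['uncond' if c == 1 else 'cond' for c in cond_or_uncond])  # ['cond']
    use_eligen = entity_prompt_emb is not None and 0 in cond_or_uncond
    print(use_eligen)  # True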
@@ -878,11 +797,10 @@ class QwenImageTransformer2DModel(nn.Module):
             height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int)
             width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int)

-            if os.environ.get("ELIGEN_DEBUG"):
-                print(f"[EliGen Debug] Original latent shape: {x.shape}")
-                print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}")
-                print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}")
-                print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}")
+            print(f"[EliGen Debug] Original latent shape: {x.shape}")
+            print(f"[EliGen Debug] Padded latent shape (orig_shape): {orig_shape}")
+            print(f"[EliGen Debug] Calculated pixel dimensions: {height}x{width}")
+            print(f"[EliGen Debug] Expected patches: {height//16}x{width//16}")

             # Call process_entity_masks to get concatenated text, RoPE, and attention mask
             encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
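The dimension arithmetic in this hunk assumes the padded latent is 1/8 of pixel resolution and that the transformer uses 16-pixel patches; a worked example with an assumed 128x128 latent:

    orig_shape = (1, 16, 128, 128)      # assumed padded latent: [B, C, H/8, W/8]

    height = int(orig_shape[-2] * 8)    # 1024 pixels
    width = int(orig_shape[-1] * 8)     # 1024 pixels
    print(height // 16, width // 16)    # 64 64 -> a 64x64 patch grid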
@@ -119,7 +119,6 @@ def convert_tensor(extra, dtype, device):
         extra = comfy.model_management.cast_to_device(extra, device, None)
     return extra

-
 class BaseModel(torch.nn.Module):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
         super().__init__()
@@ -381,7 +380,6 @@ class BaseModel(torch.nn.Module):
     def extra_conds_shapes(self, **kwargs):
         return {}

-
 def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
     adm_inputs = []
     weights = []
@@ -477,7 +475,6 @@ class SDXL(BaseModel):
         flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
         return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

-
 class SVD_img2vid(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -554,7 +551,6 @@ class SV3D_p(SVD_img2vid):
         out = list(map(lambda a: utils.resize_to_batch_size(a, noise.shape[0]), out))
         return torch.cat(out, dim=1)

-
 class Stable_Zero123(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
         super().__init__(model_config, model_type, device=device)
@@ -638,13 +634,11 @@ class IP2P:
         image = utils.resize_to_batch_size(image, noise.shape[0])
         return self.process_ip2p_image_in(image)

-
 class SD15_instructpix2pix(IP2P, BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
         self.process_ip2p_image_in = lambda image: image

-
 class SDXL_instructpix2pix(IP2P, SDXL):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
         super().__init__(model_config, model_type, device=device)
@@ -694,7 +688,6 @@ class StableCascade_C(BaseModel):
         out['clip_text'] = comfy.conds.CONDCrossAttn(cross_attn)
         return out

-
 class StableCascade_B(BaseModel):
     def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=StageB)
@@ -714,7 +707,6 @@ class StableCascade_B(BaseModel):
         out["sca"] = comfy.conds.CONDRegular(torch.zeros((1,)))
         return out

-
 class SD3(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper)
@@ -729,7 +721,6 @@ class SD3(BaseModel):
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out

-
 class AuraFlow(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.aura.mmdit.MMDiT)
@@ -741,7 +732,6 @@ class AuraFlow(BaseModel):
         out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out

-
 class StableAudio1(BaseModel):
     def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer)
@@ -780,7 +770,6 @@ class StableAudio1(BaseModel):
                 sd["{}{}".format(k, l)] = s[l]
         return sd

-
 class HunyuanDiT(BaseModel):
     def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
@@ -914,7 +903,6 @@ class Flux(BaseModel):
             out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
         return out

-
 class GenmoMochi(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -1166,7 +1154,6 @@ class WAN21(BaseModel):

         return out

-
 class WAN21_Vace(WAN21):
     def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
@@ -1466,15 +1453,17 @@ class QwenImage(BaseModel):
         # Handle EliGen entity data
         entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
         if entity_prompt_emb is not None:
-            out['entity_prompt_emb'] = entity_prompt_emb # Already wrapped in CONDList by node
+            out['entity_prompt_emb'] = comfy.conds.CONDList(entity_prompt_emb)

         entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
         if entity_prompt_emb_mask is not None:
-            out['entity_prompt_emb_mask'] = entity_prompt_emb_mask # Already wrapped in CONDList by node
+            out['entity_prompt_emb_mask'] = comfy.conds.CONDList(entity_prompt_emb_mask)

         entity_masks = kwargs.get("entity_masks", None)
         if entity_masks is not None:
-            out['entity_masks'] = entity_masks # Already wrapped in CONDRegular by node
+            out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)

+        # import pdb; pdb.set_trace()
+
         return out

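This hunk moves the CONDList/CONDRegular wrapping out of the node and into extra_conds, so the sampler-side conditioning objects are always built in one place. A hedged sketch of the resulting pattern (the helper below is illustrative, not the full QwenImage.extra_conds, and it assumes a ComfyUI environment where comfy.conds is importable):

    import comfy.conds

    def wrap_entity_conds(out, **kwargs):
        # Wrap raw tensors/lists coming from the node into conditioning objects here.
        entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
        if entity_prompt_emb is not None:
            out['entity_prompt_emb'] = comfy.conds.CONDList(entity_prompt_emb)

        entity_masks = kwargs.get("entity_masks", None)
        if entity_masks is not None:
            out['entity_masks'] = comfy.conds.CONDRegular(entity_masks)
        return out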
@@ -176,17 +176,12 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
         entity_prompt_emb_list = []
         entity_prompt_emb_mask_list = []

-        for entity_prompt, _ in valid_entities:
+        for entity_prompt, _ in valid_entities: # mask not used at this point
             entity_tokens = clip.tokenize(entity_prompt)
-            entity_cond = clip.encode_from_tokens_scheduled(entity_tokens)
+            entity_cond_dict = clip.encode_from_tokens(entity_tokens, return_pooled=True, return_dict=True)

-            # Extract embeddings and masks from conditioning
-            # Conditioning format: [[cond_tensor, extra_dict], ...]
-            entity_prompt_emb = entity_cond[0][0] # The embedding tensor directly [1, seq_len, 3584]
-            extra_dict = entity_cond[0][1] # Metadata dict (pooled_output, attention_mask, etc.)
-
-            # Extract attention mask from metadata dict
-            entity_prompt_emb_mask = extra_dict.get("attention_mask", None)
+            entity_prompt_emb = entity_cond_dict["cond"]
+            entity_prompt_emb_mask = entity_cond_dict.get("attention_mask", None)

             # If no attention mask in extra_dict, create one (all True)
             if entity_prompt_emb_mask is None:
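The switch here replaces encode_from_tokens_scheduled (which yields a conditioning list of [tensor, extra_dict] pairs) with encode_from_tokens(..., return_dict=True), which returns a dict exposing the embedding under "cond" plus extras such as "attention_mask". A usage sketch; clip is the text-encoder object a ComfyUI node receives, and the prompt string is arbitrary:

    import torch

    entity_tokens = clip.tokenize("a red vintage car")
    entity_cond_dict = clip.encode_from_tokens(entity_tokens, return_pooled=True, return_dict=True)

    entity_prompt_emb = entity_cond_dict["cond"]
    entity_prompt_emb_mask = entity_cond_dict.get("attention_mask", None)
    if entity_prompt_emb_mask is None:
        # Fall back to an all-True mask over the token dimension.
        entity_prompt_emb_mask = torch.ones(
            (entity_prompt_emb.shape[0], entity_prompt_emb.shape[1]),
            dtype=torch.bool, device=entity_prompt_emb.device)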
@@ -194,11 +189,12 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
                 entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len),
                                                     dtype=torch.bool, device=entity_prompt_emb.device)

+
             entity_prompt_emb_list.append(entity_prompt_emb)
             entity_prompt_emb_mask_list.append(entity_prompt_emb_mask)

         # Process spatial masks to latent space
-        processed_masks = []
+        processed_entity_masks = []
         for i, (_, mask) in enumerate(valid_entities):
             # mask is expected to be [batch, height, width, channels] or [batch, height, width]
             mask_tensor = mask
@@ -244,11 +240,11 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
                 total_pixels = resized_mask.numel()
                 print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)")

-            processed_masks.append(resized_mask)
+            processed_entity_masks.append(resized_mask)

-        # Stack masks: [batch, num_entities, 1, latent_height, latent_width]
+        # Stack masks: [batch, num_entities, 1, latent_height, latent_width] (1 is selected channel)
         # No padding - handle dynamic number of entities
-        entity_masks_tensor = torch.stack(processed_masks, dim=1)
+        entity_masks_tensor = torch.stack(processed_entity_masks, dim=1)

         # Extract global prompt embedding and mask from conditioning
         # Conditioning format: [[cond_tensor, extra_dict]]
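For the stacking step above, a tiny standalone check of the resulting layout (mask sizes assumed):

    import torch

    # Two per-entity masks already resized to latent resolution [B, 1, H/8, W/8] (assumed 128x128).
    processed_entity_masks = [torch.ones(1, 1, 128, 128), torch.zeros(1, 1, 128, 128)]

    # New entity dimension at dim=1: [batch, num_entities, 1, latent_h, latent_w]
    entity_masks_tensor = torch.stack(processed_entity_masks, dim=1)
    print(entity_masks_tensor.shape)  # torch.Size([1, 2, 1, 128, 128])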
@@ -263,11 +259,10 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
                                                 dtype=torch.bool, device=global_prompt_emb.device)

         # Attach entity data to conditioning using conditioning_set_values
-        # Wrap lists in CONDList so they can be properly concatenated during CFG
         entity_data = {
-            "entity_prompt_emb": comfy.conds.CONDList(entity_prompt_emb_list),
-            "entity_prompt_emb_mask": comfy.conds.CONDList(entity_prompt_emb_mask_list),
-            "entity_masks": comfy.conds.CONDRegular(entity_masks_tensor),
+            "entity_prompt_emb": entity_prompt_emb_list,
+            "entity_prompt_emb_mask": entity_prompt_emb_mask_list,
+            "entity_masks": entity_masks_tensor,
         }

         conditioning_with_entities = node_helpers.conditioning_set_values(