mask instead of image inputs for qwen eligen pipeline

This commit is contained in:
nolan4 2025-10-24 18:22:39 -07:00
parent 0f4a141faf
commit b0ade4bb85

View File

@ -116,11 +116,11 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
io.Clip.Input("clip"),
io.Conditioning.Input("global_conditioning"),
io.Latent.Input("latent"),
io.Image.Input("entity_mask_1", optional=True),
io.Mask.Input("entity_mask_1", optional=True),
io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("entity_mask_2", optional=True),
io.Mask.Input("entity_mask_2", optional=True),
io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("entity_mask_3", optional=True),
io.Mask.Input("entity_mask_3", optional=True),
io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
],
outputs=[
@ -196,31 +196,26 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
# Process spatial masks to latent space
processed_entity_masks = []
for i, (_, mask) in enumerate(valid_entities):
# mask is expected to be [batch, height, width, channels] or [batch, height, width]
# MASK type format: [batch, height, width] (no channel dimension)
# This is different from IMAGE type which is [batch, height, width, channels]
mask_tensor = mask
# Log original mask dimensions
original_shape = mask_tensor.shape
if len(original_shape) == 3:
if len(original_shape) == 2:
# [height, width] - single mask without batch
orig_h, orig_w = original_shape[0], original_shape[1]
elif len(original_shape) == 4:
# Add batch dimension: [1, height, width]
mask_tensor = mask_tensor.unsqueeze(0)
elif len(original_shape) == 3:
# [batch, height, width] - standard MASK format
orig_h, orig_w = original_shape[1], original_shape[2]
else:
orig_h, orig_w = original_shape[-2], original_shape[-1]
raise ValueError(f"Unexpected mask shape: {original_shape}. Expected [H, W] or [B, H, W]")
print(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
# Ensure mask is in [batch, channels, height, width] format for upscale
if len(mask_tensor.shape) == 3:
# [height, width, channels] -> [1, height, width, channels] (add batch dimension)
mask_tensor = mask_tensor.unsqueeze(0)
elif len(mask_tensor.shape) == 4 and mask_tensor.shape[-1] in [1, 3, 4]:
# [batch, height, width, channels] -> [batch, channels, height, width]
mask_tensor = mask_tensor.movedim(-1, 1)
# Take only first channel if multiple channels
if mask_tensor.shape[1] > 1:
mask_tensor = mask_tensor[:, 0:1, :, :]
# Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale
# common_upscale expects [batch, channels, height, width]
mask_tensor = mask_tensor.unsqueeze(1) # Add channel dimension: [batch, 1, height, width]
# Resize to latent space dimensions using nearest neighbor
resized_mask = comfy.utils.common_upscale(
@ -238,13 +233,16 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
# Log how many pixels are active in the mask
active_pixels = (resized_mask > 0).sum().item()
total_pixels = resized_mask.numel()
print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)")
processed_entity_masks.append(resized_mask)
# Stack masks: [batch, num_entities, 1, latent_height, latent_width] (1 is selected channel)
# No padding - handle dynamic number of entities
entity_masks_tensor = torch.stack(processed_entity_masks, dim=1)
# Stack masks: [batch, num_entities, 1, latent_height, latent_width]
# Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1)
# We need to remove batch dim, stack, then add it back
# Option 1: Squeeze batch dim from each mask
processed_no_batch = [m.squeeze(0) for m in processed_entity_masks] # Each: [1, H, W]
entity_masks_tensor = torch.stack(processed_no_batch, dim=0) # [num_entities, 1, H, W]
entity_masks_tensor = entity_masks_tensor.unsqueeze(0) # [1, num_entities, 1, H, W]
# Extract global prompt embedding and mask from conditioning
# Conditioning format: [[cond_tensor, extra_dict]]