From b0ade4bb85ad6834449b374ac3422e239da109d0 Mon Sep 17 00:00:00 2001
From: nolan4
Date: Fri, 24 Oct 2025 18:22:39 -0700
Subject: [PATCH] mask instead of image inputs for qwen eligen pipeline

---
 comfy_extras/nodes_qwen.py | 46 ++++++++++++++++++--------------------
 1 file changed, 22 insertions(+), 24 deletions(-)

diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py
index 5ac48dc36..d90707a49 100644
--- a/comfy_extras/nodes_qwen.py
+++ b/comfy_extras/nodes_qwen.py
@@ -116,11 +116,11 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
                 io.Clip.Input("clip"),
                 io.Conditioning.Input("global_conditioning"),
                 io.Latent.Input("latent"),
-                io.Image.Input("entity_mask_1", optional=True),
+                io.Mask.Input("entity_mask_1", optional=True),
                 io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
-                io.Image.Input("entity_mask_2", optional=True),
+                io.Mask.Input("entity_mask_2", optional=True),
                 io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
-                io.Image.Input("entity_mask_3", optional=True),
+                io.Mask.Input("entity_mask_3", optional=True),
                 io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
             ],
             outputs=[
@@ -196,31 +196,26 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
         # Process spatial masks to latent space
         processed_entity_masks = []
         for i, (_, mask) in enumerate(valid_entities):
-            # mask is expected to be [batch, height, width, channels] or [batch, height, width]
+            # MASK type format: [batch, height, width] (no channel dimension)
+            # This is different from IMAGE type, which is [batch, height, width, channels]
             mask_tensor = mask
 
             # Log original mask dimensions
             original_shape = mask_tensor.shape
-            if len(original_shape) == 3:
+            if len(original_shape) == 2:
+                # [height, width] - single mask without batch dimension
                 orig_h, orig_w = original_shape[0], original_shape[1]
-            elif len(original_shape) == 4:
+                # Add batch dimension: [1, height, width]
+                mask_tensor = mask_tensor.unsqueeze(0)
+            elif len(original_shape) == 3:
+                # [batch, height, width] - standard MASK format
                 orig_h, orig_w = original_shape[1], original_shape[2]
             else:
-                orig_h, orig_w = original_shape[-2], original_shape[-1]
+                raise ValueError(f"Unexpected mask shape: {original_shape}. Expected [H, W] or [B, H, W]")
 
-            print(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
-
-            # Ensure mask is in [batch, channels, height, width] format for upscale
-            if len(mask_tensor.shape) == 3:
-                # [height, width, channels] -> [1, height, width, channels] (add batch dimension)
-                mask_tensor = mask_tensor.unsqueeze(0)
-            elif len(mask_tensor.shape) == 4 and mask_tensor.shape[-1] in [1, 3, 4]:
-                # [batch, height, width, channels] -> [batch, channels, height, width]
-                mask_tensor = mask_tensor.movedim(-1, 1)
-
-            # Take only first channel if multiple channels
-            if mask_tensor.shape[1] > 1:
-                mask_tensor = mask_tensor[:, 0:1, :, :]
+            # Convert MASK format [batch, height, width] to [batch, 1, height, width],
+            # since common_upscale expects [batch, channels, height, width]
+            mask_tensor = mask_tensor.unsqueeze(1)  # Add channel dimension: [batch, 1, height, width]
 
             # Resize to latent space dimensions using nearest neighbor
             resized_mask = comfy.utils.common_upscale(
@@ -238,13 +233,16 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
             # Log how many pixels are active in the mask
             active_pixels = (resized_mask > 0).sum().item()
             total_pixels = resized_mask.numel()
-            print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)")
 
             processed_entity_masks.append(resized_mask)
 
-        # Stack masks: [batch, num_entities, 1, latent_height, latent_width] (1 is selected channel)
-        # No padding - handle dynamic number of entities
-        entity_masks_tensor = torch.stack(processed_entity_masks, dim=1)
+        # Stack masks: [batch, num_entities, 1, latent_height, latent_width]
+        # Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1),
+        # so squeeze the batch dim from each mask, stack along a new entity dim,
+        # then restore the batch dim
+        processed_no_batch = [m.squeeze(0) for m in processed_entity_masks]  # Each: [1, H, W]
+        entity_masks_tensor = torch.stack(processed_no_batch, dim=0)  # [num_entities, 1, H, W]
+        entity_masks_tensor = entity_masks_tensor.unsqueeze(0)  # [1, num_entities, 1, H, W]
 
         # Extract global prompt embedding and mask from conditioning
         # Conditioning format: [[cond_tensor, extra_dict]]
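
Reviewer note: for anyone tracing the shape handling in the second hunk, below is a
minimal standalone sketch of the same normalization. It is illustrative only:
normalize_mask, latent_h, and latent_w are made-up names, and
torch.nn.functional.interpolate stands in for comfy.utils.common_upscale (both
perform a nearest-neighbor resize here).

    import torch
    import torch.nn.functional as F

    def normalize_mask(mask: torch.Tensor, latent_h: int, latent_w: int) -> torch.Tensor:
        # Hypothetical helper mirroring the hunk above: accepts a MASK tensor
        # shaped [H, W] or [B, H, W] and returns [B, 1, latent_h, latent_w].
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)   # [H, W] -> [1, H, W]
        elif mask.dim() != 3:
            raise ValueError(f"Unexpected mask shape: {tuple(mask.shape)}. Expected [H, W] or [B, H, W]")
        mask = mask.unsqueeze(1)       # [B, H, W] -> [B, 1, H, W]
        # Nearest-neighbor resizing keeps a binary mask binary instead of blending edges.
        return F.interpolate(mask, size=(latent_h, latent_w), mode="nearest")

    resized = normalize_mask(torch.ones(512, 512), latent_h=64, latent_w=64)
    assert resized.shape == (1, 1, 64, 64)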
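
One observation on the last hunk: because every mask leaves the loop with batch
size 1, the squeeze/stack/unsqueeze sequence is equivalent to stacking the
[1, 1, H, W] masks directly along dim=1, which inserts the entity dimension
right after the batch dimension. A quick self-contained check (shapes
illustrative):

    import torch

    masks = [torch.rand(1, 1, 64, 64) for _ in range(3)]  # three [1, 1, H, W] entity masks

    # Patch version: drop the batch dim, stack entities, restore the batch dim.
    a = torch.stack([m.squeeze(0) for m in masks], dim=0).unsqueeze(0)

    # Equivalent one-liner: stack along dim=1.
    b = torch.stack(masks, dim=1)

    assert a.shape == b.shape == (1, 3, 1, 64, 64)
    assert torch.equal(a, b)

Either form works; the spelled-out version in the patch just makes the
intermediate shapes explicit.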