Mirror of https://github.com/comfyanonymous/ComfyUI.git

commit b0ade4bb85 (parent 0f4a141faf)

    mask instead of image inputs for qwen eligen pipeline
@@ -116,11 +116,11 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
                 io.Clip.Input("clip"),
                 io.Conditioning.Input("global_conditioning"),
                 io.Latent.Input("latent"),
-                io.Image.Input("entity_mask_1", optional=True),
+                io.Mask.Input("entity_mask_1", optional=True),
                 io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
-                io.Image.Input("entity_mask_2", optional=True),
+                io.Mask.Input("entity_mask_2", optional=True),
                 io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
-                io.Image.Input("entity_mask_3", optional=True),
+                io.Mask.Input("entity_mask_3", optional=True),
                 io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
             ],
             outputs=[
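Switching from io.Image.Input to io.Mask.Input changes the tensor layout the node receives: ComfyUI IMAGE tensors are [batch, height, width, channels] floats in [0, 1], while MASK tensors are [batch, height, width] with no channel dimension. A minimal sketch of the two layouts, assuming standard ComfyUI tensor conventions; the image_to_mask helper is illustrative, not part of this commit:

import torch

# IMAGE: [batch, height, width, channels], float32 in [0, 1]
image = torch.rand(1, 512, 512, 3)

# MASK: [batch, height, width], float32 in [0, 1], no channel dimension
mask = torch.rand(1, 512, 512)

# Converting an IMAGE to a MASK (conceptually what ComfyUI's ImageToMask
# node does): select one channel, which drops the channel dimension.
def image_to_mask(image: torch.Tensor, channel: int = 0) -> torch.Tensor:
    return image[:, :, :, channel]

assert image_to_mask(image).shape == (1, 512, 512)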
@@ -196,31 +196,26 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
         # Process spatial masks to latent space
         processed_entity_masks = []
         for i, (_, mask) in enumerate(valid_entities):
-            # mask is expected to be [batch, height, width, channels] or [batch, height, width]
+            # MASK type format: [batch, height, width] (no channel dimension)
+            # This is different from IMAGE type which is [batch, height, width, channels]
             mask_tensor = mask
 
             # Log original mask dimensions
             original_shape = mask_tensor.shape
-            if len(original_shape) == 3:
+            if len(original_shape) == 2:
+                # [height, width] - single mask without batch
                 orig_h, orig_w = original_shape[0], original_shape[1]
-            elif len(original_shape) == 4:
+                # Add batch dimension: [1, height, width]
+                mask_tensor = mask_tensor.unsqueeze(0)
+            elif len(original_shape) == 3:
+                # [batch, height, width] - standard MASK format
                 orig_h, orig_w = original_shape[1], original_shape[2]
             else:
-                orig_h, orig_w = original_shape[-2], original_shape[-1]
+                raise ValueError(f"Unexpected mask shape: {original_shape}. Expected [H, W] or [B, H, W]")
 
             print(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
 
-            # Ensure mask is in [batch, channels, height, width] format for upscale
-            if len(mask_tensor.shape) == 3:
-                # [height, width, channels] -> [1, height, width, channels] (add batch dimension)
-                mask_tensor = mask_tensor.unsqueeze(0)
-            elif len(mask_tensor.shape) == 4 and mask_tensor.shape[-1] in [1, 3, 4]:
-                # [batch, height, width, channels] -> [batch, channels, height, width]
-                mask_tensor = mask_tensor.movedim(-1, 1)
-
-            # Take only first channel if multiple channels
-            if mask_tensor.shape[1] > 1:
-                mask_tensor = mask_tensor[:, 0:1, :, :]
+            # Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale
+            # common_upscale expects [batch, channels, height, width]
+            mask_tensor = mask_tensor.unsqueeze(1)  # Add channel dimension: [batch, 1, height, width]
 
             # Resize to latent space dimensions using nearest neighbor
             resized_mask = comfy.utils.common_upscale(
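This hunk normalizes every incoming mask to [batch, 1, height, width] before resizing it to the latent grid; the comfy.utils.common_upscale call itself is truncated at the hunk boundary. A standalone sketch of the same normalize-and-resize flow, using torch.nn.functional.interpolate with nearest-neighbor sampling as a stand-in for common_upscale; the function name and the 64x64 latent size are illustrative, not from the commit:

import torch
import torch.nn.functional as F

def mask_to_latent_grid(mask: torch.Tensor, latent_height: int, latent_width: int) -> torch.Tensor:
    """Normalize a MASK tensor ([H, W] or [B, H, W]) and resize it to latent resolution."""
    if mask.dim() == 2:
        mask = mask.unsqueeze(0)   # [H, W] -> [1, H, W]
    elif mask.dim() != 3:
        raise ValueError(f"Unexpected mask shape: {tuple(mask.shape)}")
    mask = mask.unsqueeze(1)       # [B, H, W] -> [B, 1, H, W]
    # Nearest-neighbor keeps a binary mask binary; bilinear would blur edge values.
    return F.interpolate(mask, size=(latent_height, latent_width), mode="nearest")

resized = mask_to_latent_grid(torch.ones(512, 512), 64, 64)
assert resized.shape == (1, 1, 64, 64)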
@@ -238,13 +233,16 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
             # Log how many pixels are active in the mask
             active_pixels = (resized_mask > 0).sum().item()
             total_pixels = resized_mask.numel()
             print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)")
 
             processed_entity_masks.append(resized_mask)
 
-        # Stack masks: [batch, num_entities, 1, latent_height, latent_width] (1 is selected channel)
-        # No padding - handle dynamic number of entities
-        entity_masks_tensor = torch.stack(processed_entity_masks, dim=1)
+        # Stack masks: [batch, num_entities, 1, latent_height, latent_width]
+        # Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1)
+        # We need to remove batch dim, stack, then add it back
+        # Option 1: Squeeze batch dim from each mask
+        processed_no_batch = [m.squeeze(0) for m in processed_entity_masks]  # Each: [1, H, W]
+        entity_masks_tensor = torch.stack(processed_no_batch, dim=0)  # [num_entities, 1, H, W]
+        entity_masks_tensor = entity_masks_tensor.unsqueeze(0)  # [1, num_entities, 1, H, W]
 
         # Extract global prompt embedding and mask from conditioning
         # Conditioning format: [[cond_tensor, extra_dict]]
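The new stacking path squeezes the per-mask batch dimension, stacks along a fresh entity axis, and restores the batch dimension. For batch size 1 this is shape-equivalent to the removed one-liner, torch.stack(processed_entity_masks, dim=1), which stacks the [1, 1, H, W] masks at dim=1 directly. A quick check of both routes (a sketch for verification, not code from the commit):

import torch

masks = [torch.rand(1, 1, 64, 64) for _ in range(3)]  # three [1, 1, H, W] entity masks

# New route: squeeze batch, stack entities, restore batch.
no_batch = [m.squeeze(0) for m in masks]             # each [1, 64, 64]
stacked = torch.stack(no_batch, dim=0).unsqueeze(0)  # [1, 3, 1, 64, 64]

# Old route: stack at dim=1 directly.
direct = torch.stack(masks, dim=1)                   # [1, 3, 1, 64, 64]

assert stacked.shape == direct.shape == (1, 3, 1, 64, 64)
assert torch.equal(stacked, direct)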