mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-20 11:32:58 +08:00
mask instead of image inputs for qwen eligen pipeline
This commit is contained in:
parent
0f4a141faf
commit
b0ade4bb85
@ -116,11 +116,11 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
|
|||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
io.Conditioning.Input("global_conditioning"),
|
io.Conditioning.Input("global_conditioning"),
|
||||||
io.Latent.Input("latent"),
|
io.Latent.Input("latent"),
|
||||||
io.Image.Input("entity_mask_1", optional=True),
|
io.Mask.Input("entity_mask_1", optional=True),
|
||||||
io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
|
io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
|
||||||
io.Image.Input("entity_mask_2", optional=True),
|
io.Mask.Input("entity_mask_2", optional=True),
|
||||||
io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
|
io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
|
||||||
io.Image.Input("entity_mask_3", optional=True),
|
io.Mask.Input("entity_mask_3", optional=True),
|
||||||
io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
|
io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
@ -196,31 +196,26 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
|
|||||||
# Process spatial masks to latent space
|
# Process spatial masks to latent space
|
||||||
processed_entity_masks = []
|
processed_entity_masks = []
|
||||||
for i, (_, mask) in enumerate(valid_entities):
|
for i, (_, mask) in enumerate(valid_entities):
|
||||||
# mask is expected to be [batch, height, width, channels] or [batch, height, width]
|
# MASK type format: [batch, height, width] (no channel dimension)
|
||||||
|
# This is different from IMAGE type which is [batch, height, width, channels]
|
||||||
mask_tensor = mask
|
mask_tensor = mask
|
||||||
|
|
||||||
# Log original mask dimensions
|
# Log original mask dimensions
|
||||||
original_shape = mask_tensor.shape
|
original_shape = mask_tensor.shape
|
||||||
if len(original_shape) == 3:
|
if len(original_shape) == 2:
|
||||||
|
# [height, width] - single mask without batch
|
||||||
orig_h, orig_w = original_shape[0], original_shape[1]
|
orig_h, orig_w = original_shape[0], original_shape[1]
|
||||||
elif len(original_shape) == 4:
|
# Add batch dimension: [1, height, width]
|
||||||
|
mask_tensor = mask_tensor.unsqueeze(0)
|
||||||
|
elif len(original_shape) == 3:
|
||||||
|
# [batch, height, width] - standard MASK format
|
||||||
orig_h, orig_w = original_shape[1], original_shape[2]
|
orig_h, orig_w = original_shape[1], original_shape[2]
|
||||||
else:
|
else:
|
||||||
orig_h, orig_w = original_shape[-2], original_shape[-1]
|
raise ValueError(f"Unexpected mask shape: {original_shape}. Expected [H, W] or [B, H, W]")
|
||||||
|
|
||||||
print(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
|
# Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale
|
||||||
|
# common_upscale expects [batch, channels, height, width]
|
||||||
# Ensure mask is in [batch, channels, height, width] format for upscale
|
mask_tensor = mask_tensor.unsqueeze(1) # Add channel dimension: [batch, 1, height, width]
|
||||||
if len(mask_tensor.shape) == 3:
|
|
||||||
# [height, width, channels] -> [1, height, width, channels] (add batch dimension)
|
|
||||||
mask_tensor = mask_tensor.unsqueeze(0)
|
|
||||||
elif len(mask_tensor.shape) == 4 and mask_tensor.shape[-1] in [1, 3, 4]:
|
|
||||||
# [batch, height, width, channels] -> [batch, channels, height, width]
|
|
||||||
mask_tensor = mask_tensor.movedim(-1, 1)
|
|
||||||
|
|
||||||
# Take only first channel if multiple channels
|
|
||||||
if mask_tensor.shape[1] > 1:
|
|
||||||
mask_tensor = mask_tensor[:, 0:1, :, :]
|
|
||||||
|
|
||||||
# Resize to latent space dimensions using nearest neighbor
|
# Resize to latent space dimensions using nearest neighbor
|
||||||
resized_mask = comfy.utils.common_upscale(
|
resized_mask = comfy.utils.common_upscale(
|
||||||
@ -238,13 +233,16 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
|
|||||||
# Log how many pixels are active in the mask
|
# Log how many pixels are active in the mask
|
||||||
active_pixels = (resized_mask > 0).sum().item()
|
active_pixels = (resized_mask > 0).sum().item()
|
||||||
total_pixels = resized_mask.numel()
|
total_pixels = resized_mask.numel()
|
||||||
print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)")
|
|
||||||
|
|
||||||
processed_entity_masks.append(resized_mask)
|
processed_entity_masks.append(resized_mask)
|
||||||
|
|
||||||
# Stack masks: [batch, num_entities, 1, latent_height, latent_width] (1 is selected channel)
|
# Stack masks: [batch, num_entities, 1, latent_height, latent_width]
|
||||||
# No padding - handle dynamic number of entities
|
# Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1)
|
||||||
entity_masks_tensor = torch.stack(processed_entity_masks, dim=1)
|
# We need to remove batch dim, stack, then add it back
|
||||||
|
# Option 1: Squeeze batch dim from each mask
|
||||||
|
processed_no_batch = [m.squeeze(0) for m in processed_entity_masks] # Each: [1, H, W]
|
||||||
|
entity_masks_tensor = torch.stack(processed_no_batch, dim=0) # [num_entities, 1, H, W]
|
||||||
|
entity_masks_tensor = entity_masks_tensor.unsqueeze(0) # [1, num_entities, 1, H, W]
|
||||||
|
|
||||||
# Extract global prompt embedding and mask from conditioning
|
# Extract global prompt embedding and mask from conditioning
|
||||||
# Conditioning format: [[cond_tensor, extra_dict]]
|
# Conditioning format: [[cond_tensor, extra_dict]]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user