Mirror of https://github.com/comfyanonymous/ComfyUI.git
commit b222265628 (parent b0ade4bb85)

    qwen eligen batch size > 1 fix
@@ -3,6 +3,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import math
+import logging
 from typing import Optional, Tuple
 from einops import repeat, rearrange
 
@@ -12,6 +13,8 @@ from comfy.ldm.flux.layers import EmbedND
 import comfy.ldm.common_dit
 import comfy.patcher_extension
 
+logger = logging.getLogger(__name__)
+
 
 class QwenEmbedRope(nn.Module):
     """RoPE implementation for EliGen.
@@ -269,9 +272,6 @@ class Attention(nn.Module):
         txt_query = self.norm_added_q(txt_query)
         txt_key = self.norm_added_k(txt_key)
 
-        ### NEW
-        #################################################
-
         # Handle both tuple (EliGen) and single tensor (standard) RoPE formats
         if isinstance(image_rotary_emb, tuple):
             # EliGen path: Apply RoPE BEFORE concatenation (research-accurate)
@@ -303,6 +303,7 @@ class Attention(nn.Module):
             joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
             joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
 
+        # Apply EliGen attention mask if present
         effective_mask = attention_mask
         if transformer_options is not None:
             eligen_mask = transformer_options.get("eligen_attention_mask", None)
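Note: the producer side of "eligen_attention_mask" and the final assignment to effective_mask sit outside these hunks, so the wiring below is a sketch of the presumed flow, not code from this commit:

# Whatever builds the conditioning stashes the precomputed EliGen bias under a
# well-known key; each Attention block then polls transformer_options for it.
transformer_options = {"eligen_attention_mask": eligen_bias}  # eligen_bias: [1, 1, S, S]

effective_mask = attention_mask
if transformer_options is not None:
    eligen_mask = transformer_options.get("eligen_attention_mask", None)
    if eligen_mask is not None:
        effective_mask = eligen_mask  # assumed override of the standard mask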
@@ -312,11 +313,12 @@ class Attention(nn.Module):
                 # Validate shape
                 expected_seq = joint_query.shape[1]
                 if eligen_mask.shape[-1] != expected_seq:
-                    raise ValueError(f"EliGen mask shape {eligen_mask.shape} doesn't match sequence length {expected_seq}")
-        #################################################
+                    raise ValueError(
+                        f"EliGen attention mask shape mismatch: {eligen_mask.shape} "
+                        f"doesn't match sequence length {expected_seq}"
+                    )
 
-        # Standard path: Use ComfyUI's optimized attention
+        # Use ComfyUI's optimized attention
         joint_query = joint_query.flatten(start_dim=2)
         joint_key = joint_key.flatten(start_dim=2)
         joint_value = joint_value.flatten(start_dim=2)
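Note: the validation compares the mask's last dimension against joint_query.shape[1], not against anything batch-dependent, which is what keeps it batch-safe. A toy shape check with illustrative sizes (the head count and head_dim are made up):

import torch

entity_seq_lens = [12, 9]    # two entity prompts
global_seq_len = 77          # global prompt tokens
img_tokens = 4096            # e.g. a 64x64 patch grid
expected_seq = sum(entity_seq_lens) + global_seq_len + img_tokens

joint_query = torch.zeros(2, expected_seq, 24, 128)        # [batch, seq, heads, head_dim]
eligen_mask = torch.zeros(1, 1, expected_seq, expected_seq)

# Holds for any batch size: the mask is built once per sequence and
# broadcast over the batch dimension.
assert eligen_mask.shape[-1] == joint_query.shape[1]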
@@ -443,8 +445,12 @@ class LastLayer(nn.Module):
         x = torch.addcmul(shift[:, None, :], self.norm(x), (1 + scale)[:, None, :])
         return x
 
-### NEW changes
 class QwenImageTransformer2DModel(nn.Module):
+    # Constants for EliGen processing
+    LATENT_TO_PIXEL_RATIO = 8  # Latents are 8x downsampled from pixel space
+    PATCH_TO_LATENT_RATIO = 2  # 2x2 patches in latent space
+    PATCH_TO_PIXEL_RATIO = 16  # Combined: 2x2 patches on 8x downsampled latents = 16x in pixel space
+
     def __init__(
         self,
         patch_size: int = 2,
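Note: a quick sanity check of the three ratio constants, using a hypothetical 1024x1024 generation:

height, width = 1024, 1024            # padded pixel dimensions
latent_h = height // 8                # LATENT_TO_PIXEL_RATIO -> 128
latent_w = width // 8                 # -> 128
patch_h = latent_h // 2               # PATCH_TO_LATENT_RATIO -> 64
patch_w = latent_w // 2               # -> 64
assert patch_h == height // 16        # PATCH_TO_PIXEL_RATIO = 8 * 2
num_image_tokens = patch_h * patch_w  # 4096 image tokens enter the transformer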
@@ -540,8 +546,8 @@ class QwenImageTransformer2DModel(nn.Module):
             entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts
             entity_prompt_emb_mask: List[[1, L_i]]
             entity_masks: [1, N, 1, H/8, W/8]
-            height: int
-            width: int
+            height: int (padded pixel height)
+            width: int (padded pixel width)
             image: [B, patches, 64] - Patchified latents
 
         Returns:
@@ -549,6 +555,17 @@ class QwenImageTransformer2DModel(nn.Module):
             image_rotary_emb: RoPE embeddings
             attention_mask: [1, 1, total_seq, total_seq]
         """
+        num_entities = len(entity_prompt_emb)
+        batch_size = latents.shape[0]
+        logger.debug(
+            f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px image "
+            f"(latents: {latents.shape}, batch_size: {batch_size})"
+        )
+
+        # Validate batch consistency (all batches should have same sequence lengths)
+        # This is a ComfyUI requirement - batched prompts must have uniform padding
+        if batch_size > 1:
+            logger.debug(f"[EliGen Model] Batch size > 1 detected ({batch_size} batches), ensuring RoPE compatibility")
+
         # SECTION 1: Concatenate entity + global prompts
         all_prompt_emb = entity_prompt_emb + [prompt_emb]
@@ -556,45 +573,63 @@ class QwenImageTransformer2DModel(nn.Module):
         all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
 
         # SECTION 2: Build RoPE position embeddings
-        # Calculate img_shapes for RoPE (batch, height//16, width//16 for images in latent space after patchifying)
-        img_shapes = [(latents.shape[0], height//16, width//16)]
+        # For EliGen, we create RoPE for ONE batch element's dimensions
+        # The queries/keys have shape [batch, seq, heads, dim], and RoPE broadcasts across batch dim
+        patch_h = height // self.PATCH_TO_PIXEL_RATIO
+        patch_w = width // self.PATCH_TO_PIXEL_RATIO
+
+        # Create RoPE for a single image (frame=1 for images, not video)
+        # This will broadcast across all batch elements automatically
+        img_shapes_single = [(1, patch_h, patch_w)]
+
         # Calculate sequence lengths for entities and global prompt
-        entity_seq_lens = [int(mask.sum(dim=1).item()) for mask in entity_prompt_emb_mask]
+        # Use [0] to get first batch element (all batches should have same sequence lengths)
+        entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask]
+
         # Handle None case in ComfyUI (None means no padding, all tokens valid)
         if prompt_emb_mask is not None:
-            global_seq_len = int(prompt_emb_mask.sum(dim=1).item())
+            global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item())
         else:
             # No mask = no padding, use full sequence length
             global_seq_len = int(prompt_emb.shape[1])
 
         # Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs))
+        # We pass a single shape, not repeated for batch, because RoPE will broadcast
         txt_seq_lens = [global_seq_len]
-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
+        image_rotary_emb = self.pos_embed(img_shapes_single, txt_seq_lens, device=latents.device)
 
         # Create SEPARATE RoPE embeddings for each entity
         # Each entity gets its own positional encoding based on its sequence length
-        entity_rotary_emb = [self.pos_embed(img_shapes, [entity_seq_len], device=latents.device)[1]
+        # We only need to create these once since they're the same for all batch elements
+        entity_rotary_emb = [self.pos_embed([(1, patch_h, patch_w)], [entity_seq_len], device=latents.device)[1]
                              for entity_seq_len in entity_seq_lens]
 
         # Concatenate entity RoPEs with global RoPE along sequence dimension
         # Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated
+        # This creates the RoPE for ONE batch element's sequence
+        # Note: We DON'T repeat for batch_size because the queries/keys have shape [batch, seq, ...]
+        # and PyTorch will broadcast the RoPE [seq, ...] across the batch dimension automatically
         txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
 
+        logger.debug(
+            f"[EliGen Model] RoPE created for single batch element - "
+            f"img: {image_rotary_emb[0].shape}, txt: {txt_rotary_emb.shape} "
+            f"(both will broadcast across batch_size={batch_size})"
+        )
+
         # Replace text part of tuple with concatenated entity + global RoPE
         image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
 
         # SECTION 3: Prepare spatial masks
-        repeat_dim = latents.shape[1]  # 16
+        repeat_dim = latents.shape[1]  # 16 (latent channels)
         max_masks = entity_masks.shape[1]  # N entities
         entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
 
         # Pad masks to match padded latent dimensions
         # entity_masks shape: [1, N, 16, H/8, W/8]
         # Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8]
-        padded_h = height // 8
-        padded_w = width // 8
+        padded_h = height // self.LATENT_TO_PIXEL_RATIO
+        padded_w = width // self.LATENT_TO_PIXEL_RATIO
         if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
             assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \
                 f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}"
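Note: this hunk contains the actual batch-size fix. Tensor.item() only accepts one-element tensors, so summing a [batch, tokens] mask over dim=1 and calling .item() raises as soon as batch > 1; indexing [0] first is safe because ComfyUI pads batched prompts to one uniform length. A minimal reproduction:

import torch

prompt_emb_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                                [1, 1, 1, 1, 0, 0]])  # batch of 2, uniform padding

lens = prompt_emb_mask.sum(dim=1)     # tensor([4, 4]) - one entry per batch element
# lens.item()                         # RuntimeError for batch > 1 (the old code path)
global_seq_len = int(lens[0].item())  # 4 - the fixed code path
assert (lens == lens[0]).all()        # holds because padding is uniform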
@@ -602,6 +637,7 @@ class QwenImageTransformer2DModel(nn.Module):
             # Pad each entity mask
             pad_h = padded_h - entity_masks.shape[3]
             pad_w = padded_w - entity_masks.shape[4]
+            logger.debug(f"[EliGen Model] Padding entity masks by ({pad_h}, {pad_w}) to match latent dimensions")
             entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0)
 
         entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
@@ -617,12 +653,20 @@ class QwenImageTransformer2DModel(nn.Module):
         seq_lens = entity_seq_lens + [global_seq_len]
         total_seq_len = int(sum(seq_lens) + image.shape[1])
 
+        logger.debug(
+            f"[EliGen Model] Building attention mask: "
+            f"total_seq={total_seq_len} (entities: {entity_seq_lens}, global: {global_seq_len}, image: {image.shape[1]})"
+        )
+
         patched_masks = []
         for i in range(N):
             patched_mask = rearrange(
                 entity_masks[i],
                 "B C (H P) (W Q) -> B (H W) (C P Q)",
-                H=height//16, W=width//16, P=2, Q=2
+                H=height // self.PATCH_TO_PIXEL_RATIO,
+                W=width // self.PATCH_TO_PIXEL_RATIO,
+                P=self.PATCH_TO_LATENT_RATIO,
+                Q=self.PATCH_TO_LATENT_RATIO
             )
             patched_masks.append(patched_mask)
 
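Note: a shape walk-through of the rearrange pattern with toy dimensions (a 2x2 patch grid over 16 latent channels); the later sections of this method, not shown in the hunk, reduce each patch token to an attend/don't-attend decision:

import torch
from einops import rearrange

mask = torch.rand(1, 16, 4, 4)  # [B, C, H*P, W*Q] entity mask in latent space
patched = rearrange(mask, "B C (H P) (W Q) -> B (H W) (C P Q)", H=2, W=2, P=2, Q=2)
print(patched.shape)            # torch.Size([1, 4, 64]): 4 patch tokens, 64 values each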
@@ -671,10 +715,16 @@ class QwenImageTransformer2DModel(nn.Module):
 
         # SECTION 6: Convert to additive bias
         attention_mask = attention_mask.float()
+        num_valid_connections = (attention_mask == 1).sum().item()
         attention_mask[attention_mask == 0] = float('-inf')
         attention_mask[attention_mask == 1] = 0
         attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
 
+        logger.debug(
+            f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, "
+            f"valid_connections={num_valid_connections}/{total_seq_len * total_seq_len}"
+        )
+
         return all_prompt_emb, image_rotary_emb, attention_mask
 
     def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
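Note: what the -inf/0 conversion accomplishes on a 3-token toy mask; adding -inf to the pre-softmax logits zeroes the corresponding attention weights exactly:

import torch

mask = torch.tensor([[1., 1., 0.],
                     [1., 1., 0.],
                     [0., 0., 1.]])  # 1 = may attend, 0 = blocked
bias = mask.clone()
bias[bias == 0] = float('-inf')     # blocked pairs -> -inf
bias[bias == 1] = 0                 # allowed pairs -> no-op
probs = (torch.rand(3, 3) + bias).softmax(dim=-1)
assert probs[0, 2] == 0             # token 0 can never attend to token 2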

@@ -3,9 +3,13 @@ import comfy.utils
 import comfy.conds
 import math
 import torch
+import logging
+from typing import Optional
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
 
+logger = logging.getLogger(__name__)
+
 
 class TextEncodeQwenImageEdit(io.ComfyNode):
     @classmethod
@@ -105,8 +109,31 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode):
         conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
         return io.NodeOutput(conditioning)
 
-################ NEW
 class TextEncodeQwenImageEliGen(io.ComfyNode):
+    """
+    Entity-Level Image Generation (EliGen) conditioning node for Qwen Image model.
+
+    Allows specifying different prompts for different spatial regions using masks.
+    Each entity (mask + prompt pair) will only influence its masked region through
+    spatial attention masking.
+
+    Features:
+    - Supports up to 3 entities per generation
+    - Spatial attention masks prevent cross-entity contamination
+    - Separate RoPE embeddings per entity (research-accurate)
+    - Falls back to standard generation if no entities provided
+
+    Usage:
+    1. Create spatial masks using LoadImageMask (white=entity, black=background)
+    2. Use 'red', 'green', or 'blue' channel (NOT 'alpha' - it gets inverted)
+    3. Provide entity-specific prompts for each masked region
+
+    Based on DiffSynth Studio: https://github.com/modelscope/DiffSynth-Studio
+    """
+
+    # Qwen Image model uses 2x2 patches on latents (which are 8x downsampled from pixels)
+    PATCH_SIZE = 2
+
     @classmethod
     def define_schema(cls):
         return io.Schema(
@@ -129,8 +156,18 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
         )
 
     @classmethod
-    def execute(cls, clip, global_conditioning, latent, entity_prompt_1="", entity_mask_1=None,
-                entity_prompt_2="", entity_mask_2=None, entity_prompt_3="", entity_mask_3=None) -> io.NodeOutput:
+    def execute(
+        cls,
+        clip,
+        global_conditioning,
+        latent,
+        entity_prompt_1: str = "",
+        entity_mask_1: Optional[torch.Tensor] = None,
+        entity_prompt_2: str = "",
+        entity_mask_2: Optional[torch.Tensor] = None,
+        entity_prompt_3: str = "",
+        entity_mask_3: Optional[torch.Tensor] = None
+    ) -> io.NodeOutput:
 
         # Extract dimensions from latent tensor
         # latent["samples"] shape: [batch, channels, latent_h, latent_w]
@@ -139,10 +176,9 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
         unpadded_latent_width = latent_samples.shape[3]  # Unpadded latent space
 
         # Calculate padded dimensions (same logic as model's pad_to_patch_size with patch_size=2)
-        # The model pads latents to be multiples of patch_size (2 for Qwen)
-        patch_size = 2
-        pad_h = (patch_size - unpadded_latent_height % patch_size) % patch_size
-        pad_w = (patch_size - unpadded_latent_width % patch_size) % patch_size
+        # The model pads latents to be multiples of PATCH_SIZE
+        pad_h = (cls.PATCH_SIZE - unpadded_latent_height % cls.PATCH_SIZE) % cls.PATCH_SIZE
+        pad_w = (cls.PATCH_SIZE - unpadded_latent_width % cls.PATCH_SIZE) % cls.PATCH_SIZE
         latent_height = unpadded_latent_height + pad_h  # Padded latent dimensions
         latent_width = unpadded_latent_width + pad_w  # Padded latent dimensions
 
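Note: worked numbers for the padding arithmetic, assuming a hypothetical 1016x1528 pixel input (127x191 latents after the 8x VAE):

PATCH_SIZE = 2
unpadded_latent_height, unpadded_latent_width = 127, 191
pad_h = (PATCH_SIZE - unpadded_latent_height % PATCH_SIZE) % PATCH_SIZE  # 1
pad_w = (PATCH_SIZE - unpadded_latent_width % PATCH_SIZE) % PATCH_SIZE   # 1
latent_height = unpadded_latent_height + pad_h  # 128, now divisible by PATCH_SIZE
latent_width = unpadded_latent_width + pad_w    # 192
# The outer "% PATCH_SIZE" keeps already-even sizes unchanged (a pad of 2 becomes 0)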
@@ -150,8 +186,8 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
         width = latent_width * 8
 
         if pad_h > 0 or pad_w > 0:
-            print(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}")
-            print(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)")
+            logger.debug(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width} → {latent_height}x{latent_width}")
+            logger.debug(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)")
 
         # Collect entity prompts and masks
         entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3]
@@ -166,7 +202,7 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
         # Log warning if some entities were skipped
         total_prompts_provided = len([p for p in entity_prompts if p.strip()])
         if len(valid_entities) < total_prompts_provided:
-            print(f"[EliGen] Warning: Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks")
+            logger.warning(f"[EliGen] Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks")
 
         # If no valid entities, return standard conditioning
         if len(valid_entities) == 0:
@@ -200,7 +236,37 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
             # This is different from IMAGE type which is [batch, height, width, channels]
             mask_tensor = mask
 
-            # Log original mask dimensions
+            # Validate mask dtype
+            if mask_tensor.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+                raise TypeError(
+                    f"Entity {i+1} mask has invalid dtype {mask_tensor.dtype}. "
+                    f"Expected float32, float16, or bfloat16. "
+                    f"Ensure you're using LoadImageMask node, not LoadImage."
+                )
+
+            # Log original mask statistics
+            logger.debug(
+                f"[EliGen] Entity {i+1} input mask: shape={mask_tensor.shape}, "
+                f"dtype={mask_tensor.dtype}, min={mask_tensor.min():.4f}, max={mask_tensor.max():.4f}"
+            )
+
+            # Check for all-zero masks (common error when wrong channel selected)
+            if mask_tensor.max() == 0.0:
+                raise ValueError(
+                    f"Entity {i+1} mask is all zeros! This usually means:\n"
+                    f"  1. Wrong channel selected in LoadImageMask (use 'red', 'green', or 'blue', NOT 'alpha')\n"
+                    f"  2. Your mask image is completely black\n"
+                    f"  3. The mask file failed to load"
+                )
+
+            # Check for constant masks (no variation)
+            if mask_tensor.min() == mask_tensor.max() and mask_tensor.max() > 0:
+                logger.warning(
+                    f"[EliGen] Entity {i+1} mask has no variation (all pixels = {mask_tensor.min():.4f}). "
+                    f"This entity will affect the entire image."
+                )
+
+            # Extract original dimensions
             original_shape = mask_tensor.shape
             if len(original_shape) == 2:
                 # [height, width] - single mask without batch
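Note: the shape handling here distinguishes ComfyUI's MASK layout from its IMAGE layout; a reference sketch with made-up sizes:

import torch

image = torch.rand(1, 512, 512, 3)  # IMAGE: [batch, height, width, channels]
mask = torch.rand(1, 512, 512)      # MASK:  [batch, height, width]
bare = torch.rand(512, 512)         # also accepted here: [height, width]

for m in (mask, bare):
    if len(m.shape) == 2:
        orig_h, orig_w = m.shape[0], m.shape[1]
    else:
        orig_h, orig_w = m.shape[1], m.shape[2]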
@@ -211,7 +277,20 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
                 # [batch, height, width] - standard MASK format
                 orig_h, orig_w = original_shape[1], original_shape[2]
             else:
-                raise ValueError(f"Unexpected mask shape: {original_shape}. Expected [H, W] or [B, H, W]")
+                raise ValueError(
+                    f"Entity {i+1} has unexpected mask shape: {original_shape}. "
+                    f"Expected [H, W] or [B, H, W]. Got {len(original_shape)} dimensions."
+                )
+
+            # Log size mismatch if mask doesn't match expected latent dimensions
+            expected_h, expected_w = latent_height * 8, latent_width * 8
+            if orig_h != expected_h or orig_w != expected_w:
+                logger.info(
+                    f"[EliGen] Entity {i+1} mask size mismatch: {orig_h}x{orig_w} vs expected {expected_h}x{expected_w}. "
+                    f"Will resize to {latent_height}x{latent_width} latent space."
+                )
+            else:
+                logger.debug(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
 
             # Convert MASK format [batch, height, width] to [batch, 1, height, width] for common_upscale
             # common_upscale expects [batch, channels, height, width]
@@ -233,17 +312,32 @@ class TextEncodeQwenImageEliGen(io.ComfyNode):
             # Log how many pixels are active in the mask
             active_pixels = (resized_mask > 0).sum().item()
             total_pixels = resized_mask.numel()
+            coverage_pct = 100 * active_pixels / total_pixels if total_pixels > 0 else 0
+
+            if active_pixels == 0:
+                raise ValueError(
+                    f"Entity {i+1} mask has no active pixels after resizing to latent space! "
+                    f"Original mask may have been too small or all black."
+                )
+
+            logger.debug(
+                f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({coverage_pct:.1f}%)"
+            )
+
             processed_entity_masks.append(resized_mask)
 
         # Stack masks: [batch, num_entities, 1, latent_height, latent_width]
         # Each item in processed_entity_masks has shape [1, 1, H, W] (batch=1, channel=1)
         # We need to remove batch dim, stack, then add it back
-        # Option 1: Squeeze batch dim from each mask
-        processed_no_batch = [m.squeeze(0) for m in processed_entity_masks]  # Each: [1, H, W]
-        entity_masks_tensor = torch.stack(processed_no_batch, dim=0)  # [num_entities, 1, H, W]
+        processed_entity_masks_no_batch = [m.squeeze(0) for m in processed_entity_masks]  # Each: [1, H, W]
+        entity_masks_tensor = torch.stack(processed_entity_masks_no_batch, dim=0)  # [num_entities, 1, H, W]
         entity_masks_tensor = entity_masks_tensor.unsqueeze(0)  # [1, num_entities, 1, H, W]
 
+        logger.debug(
+            f"[EliGen] Stacked {len(valid_entities)} entity masks into tensor: "
+            f"shape={entity_masks_tensor.shape} (expected: [1, {len(valid_entities)}, 1, {latent_height}, {latent_width}])"
+        )
+
         # Extract global prompt embedding and mask from conditioning
         # Conditioning format: [[cond_tensor, extra_dict]]
         global_prompt_emb = global_conditioning[0][0]  # The embedding tensor directly
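Note: shape trace for the squeeze/stack/unsqueeze sequence, with two hypothetical 128x128 latent masks:

import torch

processed_entity_masks = [torch.rand(1, 1, 128, 128) for _ in range(2)]  # [1, 1, H, W] each
no_batch = [m.squeeze(0) for m in processed_entity_masks]                # [1, 128, 128] each
stacked = torch.stack(no_batch, dim=0).unsqueeze(0)                      # [1, 2, 1, 128, 128]
assert stacked.shape == (1, 2, 1, 128, 128)  # [batch, num_entities, 1, H, W]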