mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-20 19:42:59 +08:00
replace QwenEmbedRope with existing ComfyUI rope
This commit is contained in:
parent
99a25a3dc4
commit
6c09121070
@ -16,125 +16,6 @@ import comfy.patcher_extension
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class QwenEmbedRope(nn.Module):
|
|
||||||
"""RoPE implementation for EliGen.
|
|
||||||
https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L61
|
|
||||||
Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE.
|
|
||||||
"""
|
|
||||||
def __init__(self, theta: int, axes_dim: list, scale_rope=False):
|
|
||||||
super().__init__()
|
|
||||||
self.theta = theta
|
|
||||||
self.axes_dim = axes_dim
|
|
||||||
pos_index = torch.arange(4096)
|
|
||||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
|
||||||
self.pos_freqs = torch.cat([
|
|
||||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
|
||||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
|
||||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
|
||||||
], dim=1)
|
|
||||||
self.neg_freqs = torch.cat([
|
|
||||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
|
||||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
|
||||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
|
||||||
], dim=1)
|
|
||||||
self.rope_cache = {}
|
|
||||||
self.scale_rope = scale_rope
|
|
||||||
|
|
||||||
def rope_params(self, index, dim, theta=10000):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Real-valued 2x2 rotation matrix format [..., 2, 2] compatible with apply_rotary_emb
|
|
||||||
"""
|
|
||||||
assert dim % 2 == 0
|
|
||||||
freqs = torch.outer(
|
|
||||||
index,
|
|
||||||
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
|
|
||||||
)
|
|
||||||
# Convert to real-valued rotation matrix format (matches Flux rope() output)
|
|
||||||
# Rotation matrix: [[cos, -sin], [sin, cos]]
|
|
||||||
cos_freqs = torch.cos(freqs)
|
|
||||||
sin_freqs = torch.sin(freqs)
|
|
||||||
# Stack as rotation matrix: [cos, -sin, sin, cos] then reshape to [..., 2, 2]
|
|
||||||
out = torch.stack([cos_freqs, -sin_freqs, sin_freqs, cos_freqs], dim=-1)
|
|
||||||
out = out.reshape(*freqs.shape, 2, 2)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
|
|
||||||
if isinstance(video_fhw, list):
|
|
||||||
video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
|
|
||||||
_, height, width = video_fhw
|
|
||||||
if self.scale_rope:
|
|
||||||
max_vid_index = max(height // 2, width // 2)
|
|
||||||
else:
|
|
||||||
max_vid_index = max(height, width)
|
|
||||||
required_len = max_vid_index + max(txt_seq_lens)
|
|
||||||
cur_max_len = self.pos_freqs.shape[0]
|
|
||||||
if required_len <= cur_max_len:
|
|
||||||
return
|
|
||||||
|
|
||||||
new_max_len = math.ceil(required_len / 512) * 512
|
|
||||||
pos_index = torch.arange(new_max_len)
|
|
||||||
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
|
|
||||||
self.pos_freqs = torch.cat([
|
|
||||||
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
|
||||||
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
|
||||||
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
|
||||||
], dim=1)
|
|
||||||
self.neg_freqs = torch.cat([
|
|
||||||
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
|
||||||
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
|
||||||
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
|
||||||
], dim=1)
|
|
||||||
return
|
|
||||||
|
|
||||||
def forward(self, video_fhw, txt_seq_lens, device):
|
|
||||||
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
|
|
||||||
if self.pos_freqs.device != device:
|
|
||||||
self.pos_freqs = self.pos_freqs.to(device)
|
|
||||||
self.neg_freqs = self.neg_freqs.to(device)
|
|
||||||
|
|
||||||
vid_freqs = []
|
|
||||||
max_vid_index = 0
|
|
||||||
for idx, fhw in enumerate(video_fhw):
|
|
||||||
frame, height, width = fhw
|
|
||||||
rope_key = f"{idx}_{height}_{width}"
|
|
||||||
|
|
||||||
if rope_key not in self.rope_cache:
|
|
||||||
seq_lens = frame * height * width
|
|
||||||
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
|
||||||
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
|
||||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
|
||||||
if self.scale_rope:
|
|
||||||
freqs_height = torch.cat(
|
|
||||||
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
|
|
||||||
)
|
|
||||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
|
||||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
|
||||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
|
||||||
|
|
||||||
else:
|
|
||||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
|
||||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
|
||||||
|
|
||||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1, 2, 2)
|
|
||||||
self.rope_cache[rope_key] = freqs.clone().contiguous()
|
|
||||||
vid_freqs.append(self.rope_cache[rope_key])
|
|
||||||
|
|
||||||
if self.scale_rope:
|
|
||||||
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
|
||||||
else:
|
|
||||||
max_vid_index = max(height, width, max_vid_index)
|
|
||||||
|
|
||||||
max_len = max(txt_seq_lens)
|
|
||||||
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
|
||||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
|
||||||
|
|
||||||
return vid_freqs, txt_freqs
|
|
||||||
|
|
||||||
|
|
||||||
class GELU(nn.Module):
|
class GELU(nn.Module):
|
||||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -477,9 +358,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
|
||||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||||
# Add research-accurate RoPE for EliGen (returns tuple of img_freqs, txt_freqs)
|
|
||||||
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
|
|
||||||
|
|
||||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||||
embedding_dim=self.inner_dim,
|
embedding_dim=self.inner_dim,
|
||||||
pooled_projection_dim=pooled_projection_dim,
|
pooled_projection_dim=pooled_projection_dim,
|
||||||
@ -529,134 +408,109 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
||||||
|
|
||||||
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
|
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
|
||||||
entity_prompt_emb_mask, entity_masks, height, width, image):
|
entity_prompt_emb_mask, entity_masks, height, width, image,
|
||||||
|
cond_or_uncond=None, batch_size=None):
|
||||||
"""
|
"""
|
||||||
https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L434
|
|
||||||
Process entity masks and build spatial attention mask for EliGen.
|
Process entity masks and build spatial attention mask for EliGen.
|
||||||
|
|
||||||
This method:
|
Concatenates entity+global prompts, builds RoPE embeddings, creates attention mask
|
||||||
1. Concatenates entity + global prompts
|
enforcing spatial restrictions, and handles CFG batching with separate masks.
|
||||||
2. Builds RoPE embeddings for concatenated text using ComfyUI's pe_embedder
|
|
||||||
3. Creates attention mask enforcing spatial restrictions
|
|
||||||
|
|
||||||
Args:
|
Based on: https://github.com/modelscope/DiffSynth-Studio
|
||||||
latents: [B, 16, H, W]
|
|
||||||
prompt_emb: [1, seq_len, 3584] - Global prompt
|
|
||||||
prompt_emb_mask: [1, seq_len]
|
|
||||||
entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts
|
|
||||||
entity_prompt_emb_mask: List[[1, L_i]]
|
|
||||||
entity_masks: [1, N, 1, H/8, W/8]
|
|
||||||
height: int (padded pixel height)
|
|
||||||
width: int (padded pixel width)
|
|
||||||
image: [B, patches, 64] - Patchified latents
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
all_prompt_emb: [1, total_seq, 3584]
|
|
||||||
image_rotary_emb: RoPE embeddings
|
|
||||||
attention_mask: [1, 1, total_seq, total_seq]
|
|
||||||
"""
|
"""
|
||||||
num_entities = len(entity_prompt_emb)
|
num_entities = len(entity_prompt_emb)
|
||||||
batch_size = latents.shape[0]
|
actual_batch_size = latents.shape[0]
|
||||||
|
|
||||||
|
has_positive = cond_or_uncond and 0 in cond_or_uncond
|
||||||
|
has_negative = cond_or_uncond and 1 in cond_or_uncond
|
||||||
|
is_cfg_batched = has_positive and has_negative
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px image "
|
f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px, "
|
||||||
f"(latents: {latents.shape}, batch_size: {batch_size})"
|
f"batch_size={actual_batch_size}, CFG_batched={is_cfg_batched}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate batch consistency (all batches should have same sequence lengths)
|
# Concatenate entity + global prompts
|
||||||
# This is a ComfyUI requirement - batched prompts must have uniform padding
|
|
||||||
if batch_size > 1:
|
|
||||||
logger.debug(f"[EliGen Model] Batch size > 1 detected ({batch_size} batches), ensuring RoPE compatibility")
|
|
||||||
|
|
||||||
# SECTION 1: Concatenate entity + global prompts
|
|
||||||
all_prompt_emb = entity_prompt_emb + [prompt_emb]
|
all_prompt_emb = entity_prompt_emb + [prompt_emb]
|
||||||
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
|
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
|
||||||
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
|
all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
|
||||||
|
|
||||||
# SECTION 2: Build RoPE position embeddings
|
# Build RoPE embeddings
|
||||||
# For EliGen, we create RoPE for ONE batch element's dimensions
|
|
||||||
# The queries/keys have shape [batch, seq, heads, dim], and RoPE broadcasts across batch dim
|
|
||||||
patch_h = height // self.PATCH_TO_PIXEL_RATIO
|
patch_h = height // self.PATCH_TO_PIXEL_RATIO
|
||||||
patch_w = width // self.PATCH_TO_PIXEL_RATIO
|
patch_w = width // self.PATCH_TO_PIXEL_RATIO
|
||||||
|
|
||||||
# Create RoPE for a single image (frame=1 for images, not video)
|
|
||||||
# This will broadcast across all batch elements automatically
|
|
||||||
img_shapes_single = [(1, patch_h, patch_w)]
|
|
||||||
|
|
||||||
# Calculate sequence lengths for entities and global prompt
|
|
||||||
# Use [0] to get first batch element (all batches should have same sequence lengths)
|
|
||||||
entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask]
|
entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask]
|
||||||
|
|
||||||
# Handle None case in ComfyUI (None means no padding, all tokens valid)
|
|
||||||
if prompt_emb_mask is not None:
|
if prompt_emb_mask is not None:
|
||||||
global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item())
|
global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item())
|
||||||
else:
|
else:
|
||||||
# No mask = no padding, use full sequence length
|
|
||||||
global_seq_len = int(prompt_emb.shape[1])
|
global_seq_len = int(prompt_emb.shape[1])
|
||||||
|
|
||||||
# Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs))
|
max_vid_index = max(patch_h // 2, patch_w // 2)
|
||||||
# We pass a single shape, not repeated for batch, because RoPE will broadcast
|
|
||||||
txt_seq_lens = [global_seq_len]
|
|
||||||
image_rotary_emb = self.pos_embed(img_shapes_single, txt_seq_lens, device=latents.device)
|
|
||||||
|
|
||||||
# Create SEPARATE RoPE embeddings for each entity
|
# Generate per-entity text RoPE (each entity starts from same offset)
|
||||||
# Each entity gets its own positional encoding based on its sequence length
|
entity_txt_embs = []
|
||||||
# We only need to create these once since they're the same for all batch elements
|
for entity_seq_len in entity_seq_lens:
|
||||||
entity_rotary_emb = [self.pos_embed([(1, patch_h, patch_w)], [entity_seq_len], device=latents.device)[1]
|
entity_ids = torch.arange(
|
||||||
for entity_seq_len in entity_seq_lens]
|
max_vid_index,
|
||||||
|
max_vid_index + entity_seq_len,
|
||||||
|
device=latents.device
|
||||||
|
).reshape(1, -1, 1).repeat(1, 1, 3)
|
||||||
|
|
||||||
# Concatenate entity RoPEs with global RoPE along sequence dimension
|
entity_rope = self.pe_embedder(entity_ids).squeeze(1).squeeze(0)
|
||||||
# Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated
|
entity_txt_embs.append(entity_rope)
|
||||||
# This creates the RoPE for ONE batch element's sequence
|
|
||||||
# Note: We DON'T repeat for batch_size because the queries/keys have shape [batch, seq, ...]
|
|
||||||
# and PyTorch will broadcast the RoPE [seq, ...] across the batch dimension automatically
|
|
||||||
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
|
|
||||||
|
|
||||||
logger.debug(
|
# Generate global text RoPE
|
||||||
f"[EliGen Model] RoPE created for single batch element - "
|
global_ids = torch.arange(
|
||||||
f"img: {image_rotary_emb[0].shape}, txt: {txt_rotary_emb.shape} "
|
max_vid_index,
|
||||||
f"(both will broadcast across batch_size={batch_size})"
|
max_vid_index + global_seq_len,
|
||||||
)
|
device=latents.device
|
||||||
|
).reshape(1, -1, 1).repeat(1, 1, 3)
|
||||||
|
global_rope = self.pe_embedder(global_ids).squeeze(1).squeeze(0)
|
||||||
|
|
||||||
# Replace text part of tuple with concatenated entity + global RoPE
|
txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=0)
|
||||||
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
|
|
||||||
|
|
||||||
# SECTION 3: Prepare spatial masks
|
h_coords = torch.arange(-(patch_h - patch_h // 2), patch_h // 2, device=latents.device)
|
||||||
repeat_dim = latents.shape[1] # 16 (latent channels)
|
w_coords = torch.arange(-(patch_w - patch_w // 2), patch_w // 2, device=latents.device)
|
||||||
max_masks = entity_masks.shape[1] # N entities
|
|
||||||
|
img_ids = torch.zeros((patch_h, patch_w, 3), device=latents.device)
|
||||||
|
img_ids[:, :, 0] = 0
|
||||||
|
img_ids[:, :, 1] = h_coords.unsqueeze(1)
|
||||||
|
img_ids[:, :, 2] = w_coords.unsqueeze(0)
|
||||||
|
img_ids = img_ids.reshape(1, -1, 3)
|
||||||
|
|
||||||
|
img_rope = self.pe_embedder(img_ids).squeeze(1).squeeze(0)
|
||||||
|
|
||||||
|
logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")
|
||||||
|
|
||||||
|
image_rotary_emb = (img_rope, txt_rotary_emb)
|
||||||
|
|
||||||
|
# Prepare spatial masks
|
||||||
|
repeat_dim = latents.shape[1]
|
||||||
|
max_masks = entity_masks.shape[1]
|
||||||
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
|
||||||
|
|
||||||
# Pad masks to match padded latent dimensions
|
|
||||||
# entity_masks shape: [1, N, 16, H/8, W/8]
|
|
||||||
# Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8]
|
|
||||||
padded_h = height // self.LATENT_TO_PIXEL_RATIO
|
padded_h = height // self.LATENT_TO_PIXEL_RATIO
|
||||||
padded_w = width // self.LATENT_TO_PIXEL_RATIO
|
padded_w = width // self.LATENT_TO_PIXEL_RATIO
|
||||||
if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
|
if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
|
||||||
assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \
|
|
||||||
f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}"
|
|
||||||
|
|
||||||
# Pad each entity mask
|
|
||||||
pad_h = padded_h - entity_masks.shape[3]
|
pad_h = padded_h - entity_masks.shape[3]
|
||||||
pad_w = padded_w - entity_masks.shape[4]
|
pad_w = padded_w - entity_masks.shape[4]
|
||||||
logger.debug(f"[EliGen Model] Padding entity masks by ({pad_h}, {pad_w}) to match latent dimensions")
|
logger.debug(f"[EliGen Model] Padding masks by ({pad_h}, {pad_w})")
|
||||||
entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0)
|
entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0)
|
||||||
|
|
||||||
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
|
||||||
|
|
||||||
# Add global mask (all True) - must be same size as padded entity masks
|
|
||||||
global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w),
|
global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w),
|
||||||
device=latents.device, dtype=latents.dtype)
|
device=latents.device, dtype=latents.dtype)
|
||||||
entity_masks = entity_masks + [global_mask]
|
entity_masks = entity_masks + [global_mask]
|
||||||
|
|
||||||
# SECTION 4: Patchify masks
|
# Patchify masks
|
||||||
N = len(entity_masks)
|
N = len(entity_masks)
|
||||||
batch_size = int(entity_masks[0].shape[0])
|
batch_size = int(entity_masks[0].shape[0])
|
||||||
seq_lens = entity_seq_lens + [global_seq_len]
|
seq_lens = entity_seq_lens + [global_seq_len]
|
||||||
total_seq_len = int(sum(seq_lens) + image.shape[1])
|
total_seq_len = int(sum(seq_lens) + image.shape[1])
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(f"[EliGen Model] total_seq={total_seq_len}")
|
||||||
f"[EliGen Model] Building attention mask: "
|
|
||||||
f"total_seq={total_seq_len} (entities: {entity_seq_lens}, global: {global_seq_len}, image: {image.shape[1]})"
|
|
||||||
)
|
|
||||||
|
|
||||||
patched_masks = []
|
patched_masks = []
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
@ -694,7 +548,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
|
||||||
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
|
||||||
|
|
||||||
# Always repeat mask to match image sequence length (matches DiffSynth line 480)
|
# Always repeat mask to match image sequence length
|
||||||
repeat_time = single_image_seq // image_mask.shape[-1]
|
repeat_time = single_image_seq // image_mask.shape[-1]
|
||||||
image_mask = image_mask.repeat(1, 1, repeat_time)
|
image_mask = image_mask.repeat(1, 1, repeat_time)
|
||||||
|
|
||||||
@ -713,12 +567,44 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
start_j, end_j = cumsum[j], cumsum[j+1]
|
start_j, end_j = cumsum[j], cumsum[j+1]
|
||||||
attention_mask[:, start_i:end_i, start_j:end_j] = False
|
attention_mask[:, start_i:end_i, start_j:end_j] = False
|
||||||
|
|
||||||
# SECTION 6: Convert to additive bias
|
# SECTION 6: Convert to additive bias and handle CFG batching
|
||||||
attention_mask = attention_mask.float()
|
attention_mask = attention_mask.float()
|
||||||
num_valid_connections = (attention_mask == 1).sum().item()
|
num_valid_connections = (attention_mask == 1).sum().item()
|
||||||
attention_mask[attention_mask == 0] = float('-inf')
|
attention_mask[attention_mask == 0] = float('-inf')
|
||||||
attention_mask[attention_mask == 1] = 0
|
attention_mask[attention_mask == 1] = 0
|
||||||
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
|
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype)
|
||||||
|
|
||||||
|
# Handle CFG batching: Create separate masks for positive and negative
|
||||||
|
if is_cfg_batched and actual_batch_size > 1:
|
||||||
|
# CFG batch: [positive, negative] - need different masks for each
|
||||||
|
# Positive gets entity constraints, negative gets standard attention (all zeros)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[EliGen Model] CFG batched detected - creating separate masks. "
|
||||||
|
f"Positive (index 0) gets entity mask, Negative (index 1) gets standard mask"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create standard attention mask (all zeros = no constraints)
|
||||||
|
standard_mask = torch.zeros_like(attention_mask)
|
||||||
|
|
||||||
|
# Stack masks according to cond_or_uncond order
|
||||||
|
mask_list = []
|
||||||
|
for cond_type in cond_or_uncond:
|
||||||
|
if cond_type == 0: # Positive - use entity mask
|
||||||
|
mask_list.append(attention_mask[0:1]) # Take first (and only) entity mask
|
||||||
|
else: # Negative - use standard mask
|
||||||
|
mask_list.append(standard_mask[0:1])
|
||||||
|
|
||||||
|
# Concatenate masks to match batch
|
||||||
|
attention_mask = torch.cat(mask_list, dim=0)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[EliGen Model] Created {len(mask_list)} masks for CFG batch. "
|
||||||
|
f"Final shape: {attention_mask.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add head dimension: [B, 1, seq, seq]
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, "
|
f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, "
|
||||||
@ -778,23 +664,28 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||||
|
|
||||||
# Extract entity data from kwargs
|
# Extract EliGen entity data
|
||||||
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
|
entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
|
||||||
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
|
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
|
||||||
entity_masks = kwargs.get("entity_masks", None)
|
entity_masks = kwargs.get("entity_masks", None)
|
||||||
|
|
||||||
# Branch: EliGen vs Standard path
|
# Detect batch composition for CFG handling
|
||||||
# Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0)
|
|
||||||
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
|
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
|
||||||
is_positive_cond = 0 in cond_or_uncond # 0 = conditional/positive, 1 = unconditional/negative
|
is_positive_cond = 0 in cond_or_uncond
|
||||||
|
is_negative_cond = 1 in cond_or_uncond
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
|
||||||
|
if entity_prompt_emb is not None:
|
||||||
|
logger.debug(
|
||||||
|
f"[EliGen Forward] batch_size={batch_size}, cond_or_uncond={cond_or_uncond}, "
|
||||||
|
f"has_positive={is_positive_cond}, has_negative={is_negative_cond}"
|
||||||
|
)
|
||||||
|
|
||||||
if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
|
if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
|
||||||
# EliGen path - process entity masks (POSITIVE CONDITIONING ONLY)
|
# EliGen path
|
||||||
# orig_shape is from process_img which pads to patch_size
|
height = int(orig_shape[-2] * 8)
|
||||||
height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int)
|
width = int(orig_shape[-1] * 8)
|
||||||
width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int)
|
|
||||||
|
|
||||||
# Call process_entity_masks to get concatenated text, RoPE, and attention mask
|
|
||||||
encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
|
encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
|
||||||
latents=x,
|
latents=x,
|
||||||
prompt_emb=encoder_hidden_states,
|
prompt_emb=encoder_hidden_states,
|
||||||
@ -804,22 +695,21 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
entity_masks=entity_masks,
|
entity_masks=entity_masks,
|
||||||
height=height,
|
height=height,
|
||||||
width=width,
|
width=width,
|
||||||
image=hidden_states
|
image=hidden_states,
|
||||||
|
cond_or_uncond=cond_or_uncond,
|
||||||
|
batch_size=batch_size
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply image projection (text already processed in process_entity_masks)
|
|
||||||
hidden_states = self.img_in(hidden_states)
|
hidden_states = self.img_in(hidden_states)
|
||||||
|
|
||||||
# Store attention mask in transformer_options for the attention layers
|
|
||||||
if transformer_options is None:
|
if transformer_options is None:
|
||||||
transformer_options = {}
|
transformer_options = {}
|
||||||
transformer_options["eligen_attention_mask"] = eligen_attention_mask
|
transformer_options["eligen_attention_mask"] = eligen_attention_mask
|
||||||
|
|
||||||
# Clean up
|
|
||||||
del img_ids
|
del img_ids
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Standard path - existing code
|
# Standard path
|
||||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user