mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-20 19:42:59 +08:00

replace QwenEmbedRope with existing ComfyUI rope

This commit is contained in:
parent 99a25a3dc4
commit 6c09121070
@@ -16,125 +16,6 @@ import comfy.patcher_extension
logger = logging.getLogger(__name__)


class QwenEmbedRope(nn.Module):
    """RoPE implementation for EliGen.
    https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L61
    Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE.
    """
    def __init__(self, theta: int, axes_dim: list, scale_rope=False):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim
        pos_index = torch.arange(4096)
        neg_index = torch.arange(4096).flip(0) * -1 - 1
        self.pos_freqs = torch.cat([
            self.rope_params(pos_index, self.axes_dim[0], self.theta),
            self.rope_params(pos_index, self.axes_dim[1], self.theta),
            self.rope_params(pos_index, self.axes_dim[2], self.theta),
        ], dim=1)
        self.neg_freqs = torch.cat([
            self.rope_params(neg_index, self.axes_dim[0], self.theta),
            self.rope_params(neg_index, self.axes_dim[1], self.theta),
            self.rope_params(neg_index, self.axes_dim[2], self.theta),
        ], dim=1)
        self.rope_cache = {}
        self.scale_rope = scale_rope

    def rope_params(self, index, dim, theta=10000):
        """
        Args:
            index: [0, 1, 2, 3] 1D Tensor representing the position index of the token

        Returns:
            Real-valued 2x2 rotation matrix format [..., 2, 2] compatible with apply_rotary_emb
        """
        assert dim % 2 == 0
        freqs = torch.outer(
            index,
            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
        )
        # Convert to real-valued rotation matrix format (matches Flux rope() output)
        # Rotation matrix: [[cos, -sin], [sin, cos]]
        cos_freqs = torch.cos(freqs)
        sin_freqs = torch.sin(freqs)
        # Stack as rotation matrix: [cos, -sin, sin, cos] then reshape to [..., 2, 2]
        out = torch.stack([cos_freqs, -sin_freqs, sin_freqs, cos_freqs], dim=-1)
        out = out.reshape(*freqs.shape, 2, 2)
        return out

    def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
        if isinstance(video_fhw, list):
            video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
        _, height, width = video_fhw
        if self.scale_rope:
            max_vid_index = max(height // 2, width // 2)
        else:
            max_vid_index = max(height, width)
        required_len = max_vid_index + max(txt_seq_lens)
        cur_max_len = self.pos_freqs.shape[0]
        if required_len <= cur_max_len:
            return

        new_max_len = math.ceil(required_len / 512) * 512
        pos_index = torch.arange(new_max_len)
        neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
        self.pos_freqs = torch.cat([
            self.rope_params(pos_index, self.axes_dim[0], self.theta),
            self.rope_params(pos_index, self.axes_dim[1], self.theta),
            self.rope_params(pos_index, self.axes_dim[2], self.theta),
        ], dim=1)
        self.neg_freqs = torch.cat([
            self.rope_params(neg_index, self.axes_dim[0], self.theta),
            self.rope_params(neg_index, self.axes_dim[1], self.theta),
            self.rope_params(neg_index, self.axes_dim[2], self.theta),
        ], dim=1)
        return

    def forward(self, video_fhw, txt_seq_lens, device):
        self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
        if self.pos_freqs.device != device:
            self.pos_freqs = self.pos_freqs.to(device)
            self.neg_freqs = self.neg_freqs.to(device)

        vid_freqs = []
        max_vid_index = 0
        for idx, fhw in enumerate(video_fhw):
            frame, height, width = fhw
            rope_key = f"{idx}_{height}_{width}"

            if rope_key not in self.rope_cache:
                seq_lens = frame * height * width
                freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
                freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
                if self.scale_rope:
                    freqs_height = torch.cat(
                        [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
                    )
                    freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
                    freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)

                else:
                    freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
                    freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)

                freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1, 2, 2)
                self.rope_cache[rope_key] = freqs.clone().contiguous()
            vid_freqs.append(self.rope_cache[rope_key])

            if self.scale_rope:
                max_vid_index = max(height // 2, width // 2, max_vid_index)
            else:
                max_vid_index = max(height, width, max_vid_index)

        max_len = max(txt_seq_lens)
        txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
        vid_freqs = torch.cat(vid_freqs, dim=0)

        return vid_freqs, txt_freqs


class GELU(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
        super().__init__()
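For reference, the removed QwenEmbedRope packs each rotation angle as a real-valued 2x2 matrix [[cos, -sin], [sin, cos]] rather than a complex exponential, which is what makes its output interchangeable with the Flux-style rope tables used elsewhere in ComfyUI. A minimal standalone sketch of that layout (plain PyTorch only; the function name and toy shapes are illustrative, not from the codebase):

```python
import torch

def rope_rotation_matrices(index: torch.Tensor, dim: int, theta: float = 10000.0) -> torch.Tensor:
    # Angles are the outer product of positions and per-channel inverse frequencies.
    assert dim % 2 == 0
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = torch.outer(index.to(torch.float32), inv_freq)          # [len(index), dim // 2]
    cos, sin = torch.cos(angles), torch.sin(angles)
    # Pack each angle as a 2x2 rotation matrix [[cos, -sin], [sin, cos]].
    return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*angles.shape, 2, 2)

# Applying the matrices to channel pairs rotates each pair by its per-position angle,
# the same effect as multiplying a complex pair by e^{i * angle}.
pos = torch.arange(4)
rot = rope_rotation_matrices(pos, dim=8)        # [4, 4, 2, 2]
x = torch.randn(4, 4, 2)                        # 4 tokens, dim 8 viewed as 4 (real, imag) pairs
x_rot = torch.einsum("ndij,ndj->ndi", rot, x)
print(rot.shape, x_rot.shape)
```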
@@ -477,8 +358,6 @@ class QwenImageTransformer2DModel(nn.Module):
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
        # Add research-accurate RoPE for EliGen (returns tuple of img_freqs, txt_freqs)
        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)

        self.time_text_embed = QwenTimestepProjEmbeddings(
            embedding_dim=self.inner_dim,
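The retained pe_embedder follows the Flux-style EmbedND pattern: integer ids with one column per position axis are turned into per-axis rope tables and concatenated along the frequency dimension. A rough sketch of that pattern (an illustrative stand-in, not the actual ComfyUI class; the axes_dim values mirror the constructor above):

```python
import torch
import torch.nn as nn

def rope(pos: torch.Tensor, dim: int, theta: float) -> torch.Tensor:
    # pos: [..., L] positions -> [..., L, dim // 2, 2, 2] rotation matrices
    inv_freq = 1.0 / torch.pow(theta, torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = pos.to(torch.float32).unsqueeze(-1) * inv_freq
    cos, sin = torch.cos(angles), torch.sin(angles)
    return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*angles.shape, 2, 2)

class EmbedNDSketch(nn.Module):
    """Illustrative stand-in for the EmbedND used as pe_embedder."""
    def __init__(self, dim: int, theta: int, axes_dim: list):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # ids: [B, L, n_axes]; one rope table per axis, concatenated over the frequency dimension.
        emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta)
                         for i in range(ids.shape[-1])], dim=-3)
        return emb.unsqueeze(1)   # [B, 1, L, sum(axes_dim) // 2, 2, 2]

pe = EmbedNDSketch(dim=128, theta=10000, axes_dim=[16, 56, 56])
ids = torch.arange(10).reshape(1, -1, 1).repeat(1, 1, 3)   # same id on all three axes, as in the diff
print(pe(ids).shape)   # torch.Size([1, 1, 10, 64, 2, 2])
```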
@@ -529,134 +408,109 @@ class QwenImageTransformer2DModel(nn.Module):
        return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape

    def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
                             entity_prompt_emb_mask, entity_masks, height, width, image):
                             entity_prompt_emb_mask, entity_masks, height, width, image,
                             cond_or_uncond=None, batch_size=None):
        """
        https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L434
        Process entity masks and build spatial attention mask for EliGen.

        This method:
        1. Concatenates entity + global prompts
        2. Builds RoPE embeddings for concatenated text using ComfyUI's pe_embedder
        3. Creates attention mask enforcing spatial restrictions
        Concatenates entity+global prompts, builds RoPE embeddings, creates attention mask
        enforcing spatial restrictions, and handles CFG batching with separate masks.

        Args:
            latents: [B, 16, H, W]
            prompt_emb: [1, seq_len, 3584] - Global prompt
            prompt_emb_mask: [1, seq_len]
            entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts
            entity_prompt_emb_mask: List[[1, L_i]]
            entity_masks: [1, N, 1, H/8, W/8]
            height: int (padded pixel height)
            width: int (padded pixel width)
            image: [B, patches, 64] - Patchified latents

        Returns:
            all_prompt_emb: [1, total_seq, 3584]
            image_rotary_emb: RoPE embeddings
            attention_mask: [1, 1, total_seq, total_seq]
        Based on: https://github.com/modelscope/DiffSynth-Studio
        """
        num_entities = len(entity_prompt_emb)
        batch_size = latents.shape[0]
        actual_batch_size = latents.shape[0]

        has_positive = cond_or_uncond and 0 in cond_or_uncond
        has_negative = cond_or_uncond and 1 in cond_or_uncond
        is_cfg_batched = has_positive and has_negative

        logger.debug(
            f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px image "
            f"(latents: {latents.shape}, batch_size: {batch_size})"
            f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px, "
            f"batch_size={actual_batch_size}, CFG_batched={is_cfg_batched}"
        )

        # Validate batch consistency (all batches should have same sequence lengths)
        # This is a ComfyUI requirement - batched prompts must have uniform padding
        if batch_size > 1:
            logger.debug(f"[EliGen Model] Batch size > 1 detected ({batch_size} batches), ensuring RoPE compatibility")

        # SECTION 1: Concatenate entity + global prompts
        # Concatenate entity + global prompts
        all_prompt_emb = entity_prompt_emb + [prompt_emb]
        all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
        all_prompt_emb = torch.cat(all_prompt_emb, dim=1)

        # SECTION 2: Build RoPE position embeddings
        # For EliGen, we create RoPE for ONE batch element's dimensions
        # The queries/keys have shape [batch, seq, heads, dim], and RoPE broadcasts across batch dim
        # Build RoPE embeddings
        patch_h = height // self.PATCH_TO_PIXEL_RATIO
        patch_w = width // self.PATCH_TO_PIXEL_RATIO

        # Create RoPE for a single image (frame=1 for images, not video)
        # This will broadcast across all batch elements automatically
        img_shapes_single = [(1, patch_h, patch_w)]

        # Calculate sequence lengths for entities and global prompt
        # Use [0] to get first batch element (all batches should have same sequence lengths)
        entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask]

        # Handle None case in ComfyUI (None means no padding, all tokens valid)
        if prompt_emb_mask is not None:
            global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item())
        else:
            # No mask = no padding, use full sequence length
            global_seq_len = int(prompt_emb.shape[1])

        # Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs))
        # We pass a single shape, not repeated for batch, because RoPE will broadcast
        txt_seq_lens = [global_seq_len]
        image_rotary_emb = self.pos_embed(img_shapes_single, txt_seq_lens, device=latents.device)
        max_vid_index = max(patch_h // 2, patch_w // 2)

        # Create SEPARATE RoPE embeddings for each entity
        # Each entity gets its own positional encoding based on its sequence length
        # We only need to create these once since they're the same for all batch elements
        entity_rotary_emb = [self.pos_embed([(1, patch_h, patch_w)], [entity_seq_len], device=latents.device)[1]
                             for entity_seq_len in entity_seq_lens]
        # Generate per-entity text RoPE (each entity starts from same offset)
        entity_txt_embs = []
        for entity_seq_len in entity_seq_lens:
            entity_ids = torch.arange(
                max_vid_index,
                max_vid_index + entity_seq_len,
                device=latents.device
            ).reshape(1, -1, 1).repeat(1, 1, 3)

        # Concatenate entity RoPEs with global RoPE along sequence dimension
        # Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated
        # This creates the RoPE for ONE batch element's sequence
        # Note: We DON'T repeat for batch_size because the queries/keys have shape [batch, seq, ...]
        # and PyTorch will broadcast the RoPE [seq, ...] across the batch dimension automatically
        txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
            entity_rope = self.pe_embedder(entity_ids).squeeze(1).squeeze(0)
            entity_txt_embs.append(entity_rope)

        logger.debug(
            f"[EliGen Model] RoPE created for single batch element - "
            f"img: {image_rotary_emb[0].shape}, txt: {txt_rotary_emb.shape} "
            f"(both will broadcast across batch_size={batch_size})"
        )
        # Generate global text RoPE
        global_ids = torch.arange(
            max_vid_index,
            max_vid_index + global_seq_len,
            device=latents.device
        ).reshape(1, -1, 1).repeat(1, 1, 3)
        global_rope = self.pe_embedder(global_ids).squeeze(1).squeeze(0)

        # Replace text part of tuple with concatenated entity + global RoPE
        image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
        txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=0)

        # SECTION 3: Prepare spatial masks
        repeat_dim = latents.shape[1]  # 16 (latent channels)
        max_masks = entity_masks.shape[1]  # N entities
        h_coords = torch.arange(-(patch_h - patch_h // 2), patch_h // 2, device=latents.device)
        w_coords = torch.arange(-(patch_w - patch_w // 2), patch_w // 2, device=latents.device)

        img_ids = torch.zeros((patch_h, patch_w, 3), device=latents.device)
        img_ids[:, :, 0] = 0
        img_ids[:, :, 1] = h_coords.unsqueeze(1)
        img_ids[:, :, 2] = w_coords.unsqueeze(0)
        img_ids = img_ids.reshape(1, -1, 3)

        img_rope = self.pe_embedder(img_ids).squeeze(1).squeeze(0)

        logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")

        image_rotary_emb = (img_rope, txt_rotary_emb)

        # Prepare spatial masks
        repeat_dim = latents.shape[1]
        max_masks = entity_masks.shape[1]
        entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)

        # Pad masks to match padded latent dimensions
        # entity_masks shape: [1, N, 16, H/8, W/8]
        # Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8]
        padded_h = height // self.LATENT_TO_PIXEL_RATIO
        padded_w = width // self.LATENT_TO_PIXEL_RATIO
        if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
            assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \
                f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}"

            # Pad each entity mask
            pad_h = padded_h - entity_masks.shape[3]
            pad_w = padded_w - entity_masks.shape[4]
            logger.debug(f"[EliGen Model] Padding entity masks by ({pad_h}, {pad_w}) to match latent dimensions")
            logger.debug(f"[EliGen Model] Padding masks by ({pad_h}, {pad_w})")
            entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0)

        entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]

        # Add global mask (all True) - must be same size as padded entity masks
        global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w),
                                 device=latents.device, dtype=latents.dtype)
        entity_masks = entity_masks + [global_mask]

        # SECTION 4: Patchify masks
        # Patchify masks
        N = len(entity_masks)
        batch_size = int(entity_masks[0].shape[0])
        seq_lens = entity_seq_lens + [global_seq_len]
        total_seq_len = int(sum(seq_lens) + image.shape[1])

        logger.debug(
            f"[EliGen Model] Building attention mask: "
            f"total_seq={total_seq_len} (entities: {entity_seq_lens}, global: {global_seq_len}, image: {image.shape[1]})"
        )
        logger.debug(f"[EliGen Model] total_seq={total_seq_len}")

        patched_masks = []
        for i in range(N):
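The core of the replacement above: instead of QwenEmbedRope's cached frequency tables, text RoPE is now built from plain integer id grids fed to pe_embedder, with every text block (each entity prompt and the global prompt) starting at the same offset past the largest image coordinate. A condensed sketch of that id construction using the values visible in the diff (the helper name is made up):

```python
import torch

def build_text_ids(entity_seq_lens, global_seq_len, patch_h, patch_w, device="cpu"):
    # Text positions start just past the largest image coordinate (scale_rope-style halving).
    max_vid_index = max(patch_h // 2, patch_w // 2)
    blocks = []
    for seq_len in entity_seq_lens + [global_seq_len]:
        ids = torch.arange(max_vid_index, max_vid_index + seq_len, device=device)
        blocks.append(ids.reshape(1, -1, 1).repeat(1, 1, 3))   # same id on all 3 axes
    return blocks

blocks = build_text_ids(entity_seq_lens=[7, 12], global_seq_len=20, patch_h=64, patch_w=64)
print([tuple(b.shape) for b in blocks])   # [(1, 7, 3), (1, 12, 3), (1, 20, 3)]
# Each block goes through pe_embedder, and the per-entity results plus the global prompt's
# are concatenated along the sequence axis:
#     txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=0)
```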
@@ -694,7 +548,7 @@ class QwenImageTransformer2DModel(nn.Module):
            image_mask = torch.sum(patched_masks[i], dim=-1) > 0
            image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)

            # Always repeat mask to match image sequence length (matches DiffSynth line 480)
            # Always repeat mask to match image sequence length
            repeat_time = single_image_seq // image_mask.shape[-1]
            image_mask = image_mask.repeat(1, 1, repeat_time)

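The loop above turns each entity's spatial mask into one rectangular block of the joint text-image attention mask: text tokens of entity i may only attend to image tokens whose patch lies inside mask i. A small sketch of that expansion with toy shapes (all values hypothetical):

```python
import torch

# Entity i: 5 text tokens; image: 16 patch tokens (a 4x4 grid), flattened with 4 values per patch.
patched_mask = torch.zeros(1, 16, 4)
patched_mask[:, 5:9] = 1.0                                # entity covers patches 5..8

image_mask = torch.sum(patched_mask, dim=-1) > 0          # [1, 16] True where the entity covers the patch
image_mask = image_mask.unsqueeze(1).repeat(1, 5, 1)      # [1, 5, 16] one row per entity text token

# If the image sequence is longer than one image's patch count (e.g. extra reference latents
# concatenated along the sequence), the block is tiled to cover the full image sequence.
single_image_seq = 32
repeat_time = single_image_seq // image_mask.shape[-1]
image_mask = image_mask.repeat(1, 1, repeat_time)         # [1, 5, 32]
print(image_mask.shape)
```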
@@ -713,12 +567,44 @@ class QwenImageTransformer2DModel(nn.Module):
                start_j, end_j = cumsum[j], cumsum[j+1]
                attention_mask[:, start_i:end_i, start_j:end_j] = False

        # SECTION 6: Convert to additive bias
        # SECTION 6: Convert to additive bias and handle CFG batching
        attention_mask = attention_mask.float()
        num_valid_connections = (attention_mask == 1).sum().item()
        attention_mask[attention_mask == 0] = float('-inf')
        attention_mask[attention_mask == 1] = 0
        attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1)
        attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype)

        # Handle CFG batching: Create separate masks for positive and negative
        if is_cfg_batched and actual_batch_size > 1:
            # CFG batch: [positive, negative] - need different masks for each
            # Positive gets entity constraints, negative gets standard attention (all zeros)

            logger.debug(
                f"[EliGen Model] CFG batched detected - creating separate masks. "
                f"Positive (index 0) gets entity mask, Negative (index 1) gets standard mask"
            )

            # Create standard attention mask (all zeros = no constraints)
            standard_mask = torch.zeros_like(attention_mask)

            # Stack masks according to cond_or_uncond order
            mask_list = []
            for cond_type in cond_or_uncond:
                if cond_type == 0:  # Positive - use entity mask
                    mask_list.append(attention_mask[0:1])  # Take first (and only) entity mask
                else:  # Negative - use standard mask
                    mask_list.append(standard_mask[0:1])

            # Concatenate masks to match batch
            attention_mask = torch.cat(mask_list, dim=0)

            logger.debug(
                f"[EliGen Model] Created {len(mask_list)} masks for CFG batch. "
                f"Final shape: {attention_mask.shape}"
            )

        # Add head dimension: [B, 1, seq, seq]
        attention_mask = attention_mask.unsqueeze(1)

        logger.debug(
            f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, "
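The boolean joint mask is converted into an additive attention bias: allowed pairs contribute 0, blocked pairs -inf (softmax weight 0), with a head dimension added for broadcasting. A standalone sketch of just that conversion, with toy sizes:

```python
import torch

allowed = torch.ones(1, 6, 6, dtype=torch.bool)    # [B, total_seq, total_seq]
allowed[:, 0:2, 3:5] = False                        # e.g. entity-0 text must not see entity-1 text

bias = allowed.float()
bias[bias == 0] = float("-inf")                     # blocked pairs
bias[bias == 1] = 0.0                               # allowed pairs
bias = bias.unsqueeze(1)                            # [B, 1, seq, seq], broadcast over attention heads
print(bias[0, 0])
```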
@@ -778,23 +664,28 @@ class QwenImageTransformer2DModel(nn.Module):
            hidden_states = torch.cat([hidden_states, kontext], dim=1)
            img_ids = torch.cat([img_ids, kontext_ids], dim=1)

        # Extract entity data from kwargs
        # Extract EliGen entity data
        entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
        entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
        entity_masks = kwargs.get("entity_masks", None)

        # Branch: EliGen vs Standard path
        # Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0)
        # Detect batch composition for CFG handling
        cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
        is_positive_cond = 0 in cond_or_uncond  # 0 = conditional/positive, 1 = unconditional/negative
        is_positive_cond = 0 in cond_or_uncond
        is_negative_cond = 1 in cond_or_uncond
        batch_size = x.shape[0]

        if entity_prompt_emb is not None:
            logger.debug(
                f"[EliGen Forward] batch_size={batch_size}, cond_or_uncond={cond_or_uncond}, "
                f"has_positive={is_positive_cond}, has_negative={is_negative_cond}"
            )

        if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
            # EliGen path - process entity masks (POSITIVE CONDITIONING ONLY)
            # orig_shape is from process_img which pads to patch_size
            height = int(orig_shape[-2] * 8)  # Padded latent height -> pixel height (ensure int)
            width = int(orig_shape[-1] * 8)  # Padded latent width -> pixel width (ensure int)
            # EliGen path
            height = int(orig_shape[-2] * 8)
            width = int(orig_shape[-1] * 8)

            # Call process_entity_masks to get concatenated text, RoPE, and attention mask
            encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
                latents=x,
                prompt_emb=encoder_hidden_states,
@@ -804,22 +695,21 @@ class QwenImageTransformer2DModel(nn.Module):
                entity_masks=entity_masks,
                height=height,
                width=width,
                image=hidden_states
                image=hidden_states,
                cond_or_uncond=cond_or_uncond,
                batch_size=batch_size
            )

            # Apply image projection (text already processed in process_entity_masks)
            hidden_states = self.img_in(hidden_states)

            # Store attention mask in transformer_options for the attention layers
            if transformer_options is None:
                transformer_options = {}
            transformer_options["eligen_attention_mask"] = eligen_attention_mask

            # Clean up
            del img_ids

        else:
            # Standard path - existing code
            # Standard path
            txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
            txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
            ids = torch.cat((txt_ids, img_ids), dim=1)

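When ComfyUI batches positive and negative conditioning into one forward pass, the diff keys the per-sample mask off cond_or_uncond (0 = positive, 1 = negative): the positive sample keeps the entity bias, the negative sample gets an all-zero (unconstrained) bias. A minimal sketch of that selection (function name and shapes are illustrative):

```python
import torch

def masks_for_cfg_batch(entity_bias: torch.Tensor, cond_or_uncond: list) -> torch.Tensor:
    # entity_bias: [1, seq, seq] additive bias built for the positive (conditional) pass.
    standard = torch.zeros_like(entity_bias)            # no constraints for the negative pass
    rows = [entity_bias[0:1] if c == 0 else standard[0:1] for c in cond_or_uncond]
    bias = torch.cat(rows, dim=0)                       # one mask per batch element, in batch order
    return bias.unsqueeze(1)                            # [B, 1, seq, seq]

entity_bias = torch.zeros(1, 4, 4)
entity_bias[:, 0, 1] = float("-inf")
print(masks_for_cfg_batch(entity_bias, cond_or_uncond=[0, 1]).shape)   # torch.Size([2, 1, 4, 4])
```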