replace QwenEmbedRope with existing ComfyUI rope

This commit is contained in:
nolan4 2025-10-27 20:19:12 -07:00
parent 99a25a3dc4
commit 6c09121070

View File

@ -16,125 +16,6 @@ import comfy.patcher_extension
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class QwenEmbedRope(nn.Module):
"""RoPE implementation for EliGen.
https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L61
Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE.
"""
def __init__(self, theta: int, axes_dim: list, scale_rope=False):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
self.pos_freqs = torch.cat([
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
], dim=1)
self.neg_freqs = torch.cat([
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
], dim=1)
self.rope_cache = {}
self.scale_rope = scale_rope
def rope_params(self, index, dim, theta=10000):
"""
Args:
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
Returns:
Real-valued 2x2 rotation matrix format [..., 2, 2] compatible with apply_rotary_emb
"""
assert dim % 2 == 0
freqs = torch.outer(
index,
1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim))
)
# Convert to real-valued rotation matrix format (matches Flux rope() output)
# Rotation matrix: [[cos, -sin], [sin, cos]]
cos_freqs = torch.cos(freqs)
sin_freqs = torch.sin(freqs)
# Stack as rotation matrix: [cos, -sin, sin, cos] then reshape to [..., 2, 2]
out = torch.stack([cos_freqs, -sin_freqs, sin_freqs, cos_freqs], dim=-1)
out = out.reshape(*freqs.shape, 2, 2)
return out
def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens):
if isinstance(video_fhw, list):
video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3))
_, height, width = video_fhw
if self.scale_rope:
max_vid_index = max(height // 2, width // 2)
else:
max_vid_index = max(height, width)
required_len = max_vid_index + max(txt_seq_lens)
cur_max_len = self.pos_freqs.shape[0]
if required_len <= cur_max_len:
return
new_max_len = math.ceil(required_len / 512) * 512
pos_index = torch.arange(new_max_len)
neg_index = torch.arange(new_max_len).flip(0) * -1 - 1
self.pos_freqs = torch.cat([
self.rope_params(pos_index, self.axes_dim[0], self.theta),
self.rope_params(pos_index, self.axes_dim[1], self.theta),
self.rope_params(pos_index, self.axes_dim[2], self.theta),
], dim=1)
self.neg_freqs = torch.cat([
self.rope_params(neg_index, self.axes_dim[0], self.theta),
self.rope_params(neg_index, self.axes_dim[1], self.theta),
self.rope_params(neg_index, self.axes_dim[2], self.theta),
], dim=1)
return
def forward(self, video_fhw, txt_seq_lens, device):
self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens)
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
self.neg_freqs = self.neg_freqs.to(device)
vid_freqs = []
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"
if rope_key not in self.rope_cache:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
if self.scale_rope:
freqs_height = torch.cat(
[freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0
)
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
else:
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1, 2, 2)
self.rope_cache[rope_key] = freqs.clone().contiguous()
vid_freqs.append(self.rope_cache[rope_key])
if self.scale_rope:
max_vid_index = max(height // 2, width // 2, max_vid_index)
else:
max_vid_index = max(height, width, max_vid_index)
max_len = max(txt_seq_lens)
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
vid_freqs = torch.cat(vid_freqs, dim=0)
return vid_freqs, txt_freqs
class GELU(nn.Module): class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
@ -477,9 +358,7 @@ class QwenImageTransformer2DModel(nn.Module):
self.inner_dim = num_attention_heads * attention_head_dim self.inner_dim = num_attention_heads * attention_head_dim
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
# Add research-accurate RoPE for EliGen (returns tuple of img_freqs, txt_freqs)
self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True)
self.time_text_embed = QwenTimestepProjEmbeddings( self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim, embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim, pooled_projection_dim=pooled_projection_dim,
@ -529,134 +408,109 @@ class QwenImageTransformer2DModel(nn.Module):
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb,
entity_prompt_emb_mask, entity_masks, height, width, image): entity_prompt_emb_mask, entity_masks, height, width, image,
cond_or_uncond=None, batch_size=None):
""" """
https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L434
Process entity masks and build spatial attention mask for EliGen. Process entity masks and build spatial attention mask for EliGen.
This method: Concatenates entity+global prompts, builds RoPE embeddings, creates attention mask
1. Concatenates entity + global prompts enforcing spatial restrictions, and handles CFG batching with separate masks.
2. Builds RoPE embeddings for concatenated text using ComfyUI's pe_embedder
3. Creates attention mask enforcing spatial restrictions
Args: Based on: https://github.com/modelscope/DiffSynth-Studio
latents: [B, 16, H, W]
prompt_emb: [1, seq_len, 3584] - Global prompt
prompt_emb_mask: [1, seq_len]
entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts
entity_prompt_emb_mask: List[[1, L_i]]
entity_masks: [1, N, 1, H/8, W/8]
height: int (padded pixel height)
width: int (padded pixel width)
image: [B, patches, 64] - Patchified latents
Returns:
all_prompt_emb: [1, total_seq, 3584]
image_rotary_emb: RoPE embeddings
attention_mask: [1, 1, total_seq, total_seq]
""" """
num_entities = len(entity_prompt_emb) num_entities = len(entity_prompt_emb)
batch_size = latents.shape[0] actual_batch_size = latents.shape[0]
has_positive = cond_or_uncond and 0 in cond_or_uncond
has_negative = cond_or_uncond and 1 in cond_or_uncond
is_cfg_batched = has_positive and has_negative
logger.debug( logger.debug(
f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px image " f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px, "
f"(latents: {latents.shape}, batch_size: {batch_size})" f"batch_size={actual_batch_size}, CFG_batched={is_cfg_batched}"
) )
# Validate batch consistency (all batches should have same sequence lengths) # Concatenate entity + global prompts
# This is a ComfyUI requirement - batched prompts must have uniform padding
if batch_size > 1:
logger.debug(f"[EliGen Model] Batch size > 1 detected ({batch_size} batches), ensuring RoPE compatibility")
# SECTION 1: Concatenate entity + global prompts
all_prompt_emb = entity_prompt_emb + [prompt_emb] all_prompt_emb = entity_prompt_emb + [prompt_emb]
all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb] all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb]
all_prompt_emb = torch.cat(all_prompt_emb, dim=1) all_prompt_emb = torch.cat(all_prompt_emb, dim=1)
# SECTION 2: Build RoPE position embeddings # Build RoPE embeddings
# For EliGen, we create RoPE for ONE batch element's dimensions
# The queries/keys have shape [batch, seq, heads, dim], and RoPE broadcasts across batch dim
patch_h = height // self.PATCH_TO_PIXEL_RATIO patch_h = height // self.PATCH_TO_PIXEL_RATIO
patch_w = width // self.PATCH_TO_PIXEL_RATIO patch_w = width // self.PATCH_TO_PIXEL_RATIO
# Create RoPE for a single image (frame=1 for images, not video)
# This will broadcast across all batch elements automatically
img_shapes_single = [(1, patch_h, patch_w)]
# Calculate sequence lengths for entities and global prompt
# Use [0] to get first batch element (all batches should have same sequence lengths)
entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask] entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask]
# Handle None case in ComfyUI (None means no padding, all tokens valid)
if prompt_emb_mask is not None: if prompt_emb_mask is not None:
global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item()) global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item())
else: else:
# No mask = no padding, use full sequence length
global_seq_len = int(prompt_emb.shape[1]) global_seq_len = int(prompt_emb.shape[1])
# Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs)) max_vid_index = max(patch_h // 2, patch_w // 2)
# We pass a single shape, not repeated for batch, because RoPE will broadcast
txt_seq_lens = [global_seq_len]
image_rotary_emb = self.pos_embed(img_shapes_single, txt_seq_lens, device=latents.device)
# Create SEPARATE RoPE embeddings for each entity # Generate per-entity text RoPE (each entity starts from same offset)
# Each entity gets its own positional encoding based on its sequence length entity_txt_embs = []
# We only need to create these once since they're the same for all batch elements for entity_seq_len in entity_seq_lens:
entity_rotary_emb = [self.pos_embed([(1, patch_h, patch_w)], [entity_seq_len], device=latents.device)[1] entity_ids = torch.arange(
for entity_seq_len in entity_seq_lens] max_vid_index,
max_vid_index + entity_seq_len,
device=latents.device
).reshape(1, -1, 1).repeat(1, 1, 3)
# Concatenate entity RoPEs with global RoPE along sequence dimension entity_rope = self.pe_embedder(entity_ids).squeeze(1).squeeze(0)
# Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated entity_txt_embs.append(entity_rope)
# This creates the RoPE for ONE batch element's sequence
# Note: We DON'T repeat for batch_size because the queries/keys have shape [batch, seq, ...]
# and PyTorch will broadcast the RoPE [seq, ...] across the batch dimension automatically
txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0)
logger.debug( # Generate global text RoPE
f"[EliGen Model] RoPE created for single batch element - " global_ids = torch.arange(
f"img: {image_rotary_emb[0].shape}, txt: {txt_rotary_emb.shape} " max_vid_index,
f"(both will broadcast across batch_size={batch_size})" max_vid_index + global_seq_len,
) device=latents.device
).reshape(1, -1, 1).repeat(1, 1, 3)
global_rope = self.pe_embedder(global_ids).squeeze(1).squeeze(0)
# Replace text part of tuple with concatenated entity + global RoPE txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=0)
image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb)
# SECTION 3: Prepare spatial masks h_coords = torch.arange(-(patch_h - patch_h // 2), patch_h // 2, device=latents.device)
repeat_dim = latents.shape[1] # 16 (latent channels) w_coords = torch.arange(-(patch_w - patch_w // 2), patch_w // 2, device=latents.device)
max_masks = entity_masks.shape[1] # N entities
img_ids = torch.zeros((patch_h, patch_w, 3), device=latents.device)
img_ids[:, :, 0] = 0
img_ids[:, :, 1] = h_coords.unsqueeze(1)
img_ids[:, :, 2] = w_coords.unsqueeze(0)
img_ids = img_ids.reshape(1, -1, 3)
img_rope = self.pe_embedder(img_ids).squeeze(1).squeeze(0)
logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")
image_rotary_emb = (img_rope, txt_rotary_emb)
# Prepare spatial masks
repeat_dim = latents.shape[1]
max_masks = entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
# Pad masks to match padded latent dimensions
# entity_masks shape: [1, N, 16, H/8, W/8]
# Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8]
padded_h = height // self.LATENT_TO_PIXEL_RATIO padded_h = height // self.LATENT_TO_PIXEL_RATIO
padded_w = width // self.LATENT_TO_PIXEL_RATIO padded_w = width // self.LATENT_TO_PIXEL_RATIO
if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w: if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w:
assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \
f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}"
# Pad each entity mask
pad_h = padded_h - entity_masks.shape[3] pad_h = padded_h - entity_masks.shape[3]
pad_w = padded_w - entity_masks.shape[4] pad_w = padded_w - entity_masks.shape[4]
logger.debug(f"[EliGen Model] Padding entity masks by ({pad_h}, {pad_w}) to match latent dimensions") logger.debug(f"[EliGen Model] Padding masks by ({pad_h}, {pad_w})")
entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0) entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# Add global mask (all True) - must be same size as padded entity masks
global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w), global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w),
device=latents.device, dtype=latents.dtype) device=latents.device, dtype=latents.dtype)
entity_masks = entity_masks + [global_mask] entity_masks = entity_masks + [global_mask]
# SECTION 4: Patchify masks # Patchify masks
N = len(entity_masks) N = len(entity_masks)
batch_size = int(entity_masks[0].shape[0]) batch_size = int(entity_masks[0].shape[0])
seq_lens = entity_seq_lens + [global_seq_len] seq_lens = entity_seq_lens + [global_seq_len]
total_seq_len = int(sum(seq_lens) + image.shape[1]) total_seq_len = int(sum(seq_lens) + image.shape[1])
logger.debug( logger.debug(f"[EliGen Model] total_seq={total_seq_len}")
f"[EliGen Model] Building attention mask: "
f"total_seq={total_seq_len} (entities: {entity_seq_lens}, global: {global_seq_len}, image: {image.shape[1]})"
)
patched_masks = [] patched_masks = []
for i in range(N): for i in range(N):
@ -694,7 +548,7 @@ class QwenImageTransformer2DModel(nn.Module):
image_mask = torch.sum(patched_masks[i], dim=-1) > 0 image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1) image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1)
# Always repeat mask to match image sequence length (matches DiffSynth line 480) # Always repeat mask to match image sequence length
repeat_time = single_image_seq // image_mask.shape[-1] repeat_time = single_image_seq // image_mask.shape[-1]
image_mask = image_mask.repeat(1, 1, repeat_time) image_mask = image_mask.repeat(1, 1, repeat_time)
@ -713,12 +567,44 @@ class QwenImageTransformer2DModel(nn.Module):
start_j, end_j = cumsum[j], cumsum[j+1] start_j, end_j = cumsum[j], cumsum[j+1]
attention_mask[:, start_i:end_i, start_j:end_j] = False attention_mask[:, start_i:end_i, start_j:end_j] = False
# SECTION 6: Convert to additive bias # SECTION 6: Convert to additive bias and handle CFG batching
attention_mask = attention_mask.float() attention_mask = attention_mask.float()
num_valid_connections = (attention_mask == 1).sum().item() num_valid_connections = (attention_mask == 1).sum().item()
attention_mask[attention_mask == 0] = float('-inf') attention_mask[attention_mask == 0] = float('-inf')
attention_mask[attention_mask == 1] = 0 attention_mask[attention_mask == 1] = 0
attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype)
# Handle CFG batching: Create separate masks for positive and negative
if is_cfg_batched and actual_batch_size > 1:
# CFG batch: [positive, negative] - need different masks for each
# Positive gets entity constraints, negative gets standard attention (all zeros)
logger.debug(
f"[EliGen Model] CFG batched detected - creating separate masks. "
f"Positive (index 0) gets entity mask, Negative (index 1) gets standard mask"
)
# Create standard attention mask (all zeros = no constraints)
standard_mask = torch.zeros_like(attention_mask)
# Stack masks according to cond_or_uncond order
mask_list = []
for cond_type in cond_or_uncond:
if cond_type == 0: # Positive - use entity mask
mask_list.append(attention_mask[0:1]) # Take first (and only) entity mask
else: # Negative - use standard mask
mask_list.append(standard_mask[0:1])
# Concatenate masks to match batch
attention_mask = torch.cat(mask_list, dim=0)
logger.debug(
f"[EliGen Model] Created {len(mask_list)} masks for CFG batch. "
f"Final shape: {attention_mask.shape}"
)
# Add head dimension: [B, 1, seq, seq]
attention_mask = attention_mask.unsqueeze(1)
logger.debug( logger.debug(
f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, " f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, "
@ -778,23 +664,28 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = torch.cat([hidden_states, kontext], dim=1) hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)
# Extract entity data from kwargs # Extract EliGen entity data
entity_prompt_emb = kwargs.get("entity_prompt_emb", None) entity_prompt_emb = kwargs.get("entity_prompt_emb", None)
entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None)
entity_masks = kwargs.get("entity_masks", None) entity_masks = kwargs.get("entity_masks", None)
# Branch: EliGen vs Standard path # Detect batch composition for CFG handling
# Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0)
cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else []
is_positive_cond = 0 in cond_or_uncond # 0 = conditional/positive, 1 = unconditional/negative is_positive_cond = 0 in cond_or_uncond
is_negative_cond = 1 in cond_or_uncond
batch_size = x.shape[0]
if entity_prompt_emb is not None:
logger.debug(
f"[EliGen Forward] batch_size={batch_size}, cond_or_uncond={cond_or_uncond}, "
f"has_positive={is_positive_cond}, has_negative={is_negative_cond}"
)
if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond: if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond:
# EliGen path - process entity masks (POSITIVE CONDITIONING ONLY) # EliGen path
# orig_shape is from process_img which pads to patch_size height = int(orig_shape[-2] * 8)
height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int) width = int(orig_shape[-1] * 8)
width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int)
# Call process_entity_masks to get concatenated text, RoPE, and attention mask
encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks( encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks(
latents=x, latents=x,
prompt_emb=encoder_hidden_states, prompt_emb=encoder_hidden_states,
@ -804,22 +695,21 @@ class QwenImageTransformer2DModel(nn.Module):
entity_masks=entity_masks, entity_masks=entity_masks,
height=height, height=height,
width=width, width=width,
image=hidden_states image=hidden_states,
cond_or_uncond=cond_or_uncond,
batch_size=batch_size
) )
# Apply image projection (text already processed in process_entity_masks)
hidden_states = self.img_in(hidden_states) hidden_states = self.img_in(hidden_states)
# Store attention mask in transformer_options for the attention layers
if transformer_options is None: if transformer_options is None:
transformer_options = {} transformer_options = {}
transformer_options["eligen_attention_mask"] = eligen_attention_mask transformer_options["eligen_attention_mask"] = eligen_attention_mask
# Clean up
del img_ids del img_ids
else: else:
# Standard path - existing code # Standard path
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)