diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 42553154e..45996e23b 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -16,125 +16,6 @@ import comfy.patcher_extension logger = logging.getLogger(__name__) -class QwenEmbedRope(nn.Module): - """RoPE implementation for EliGen. - https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L61 - Returns a tuple (img_freqs, txt_freqs) for separate image and text RoPE. - """ - def __init__(self, theta: int, axes_dim: list, scale_rope=False): - super().__init__() - self.theta = theta - self.axes_dim = axes_dim - pos_index = torch.arange(4096) - neg_index = torch.arange(4096).flip(0) * -1 - 1 - self.pos_freqs = torch.cat([ - self.rope_params(pos_index, self.axes_dim[0], self.theta), - self.rope_params(pos_index, self.axes_dim[1], self.theta), - self.rope_params(pos_index, self.axes_dim[2], self.theta), - ], dim=1) - self.neg_freqs = torch.cat([ - self.rope_params(neg_index, self.axes_dim[0], self.theta), - self.rope_params(neg_index, self.axes_dim[1], self.theta), - self.rope_params(neg_index, self.axes_dim[2], self.theta), - ], dim=1) - self.rope_cache = {} - self.scale_rope = scale_rope - - def rope_params(self, index, dim, theta=10000): - """ - Args: - index: [0, 1, 2, 3] 1D Tensor representing the position index of the token - - Returns: - Real-valued 2x2 rotation matrix format [..., 2, 2] compatible with apply_rotary_emb - """ - assert dim % 2 == 0 - freqs = torch.outer( - index, - 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)) - ) - # Convert to real-valued rotation matrix format (matches Flux rope() output) - # Rotation matrix: [[cos, -sin], [sin, cos]] - cos_freqs = torch.cos(freqs) - sin_freqs = torch.sin(freqs) - # Stack as rotation matrix: [cos, -sin, sin, cos] then reshape to [..., 2, 2] - out = torch.stack([cos_freqs, -sin_freqs, sin_freqs, cos_freqs], dim=-1) - out = out.reshape(*freqs.shape, 2, 2) - return out - - def _expand_pos_freqs_if_needed(self, video_fhw, txt_seq_lens): - if isinstance(video_fhw, list): - video_fhw = tuple(max([i[j] for i in video_fhw]) for j in range(3)) - _, height, width = video_fhw - if self.scale_rope: - max_vid_index = max(height // 2, width // 2) - else: - max_vid_index = max(height, width) - required_len = max_vid_index + max(txt_seq_lens) - cur_max_len = self.pos_freqs.shape[0] - if required_len <= cur_max_len: - return - - new_max_len = math.ceil(required_len / 512) * 512 - pos_index = torch.arange(new_max_len) - neg_index = torch.arange(new_max_len).flip(0) * -1 - 1 - self.pos_freqs = torch.cat([ - self.rope_params(pos_index, self.axes_dim[0], self.theta), - self.rope_params(pos_index, self.axes_dim[1], self.theta), - self.rope_params(pos_index, self.axes_dim[2], self.theta), - ], dim=1) - self.neg_freqs = torch.cat([ - self.rope_params(neg_index, self.axes_dim[0], self.theta), - self.rope_params(neg_index, self.axes_dim[1], self.theta), - self.rope_params(neg_index, self.axes_dim[2], self.theta), - ], dim=1) - return - - def forward(self, video_fhw, txt_seq_lens, device): - self._expand_pos_freqs_if_needed(video_fhw, txt_seq_lens) - if self.pos_freqs.device != device: - self.pos_freqs = self.pos_freqs.to(device) - self.neg_freqs = self.neg_freqs.to(device) - - vid_freqs = [] - max_vid_index = 0 - for idx, fhw in enumerate(video_fhw): - frame, height, width = fhw - rope_key = f"{idx}_{height}_{width}" - - if rope_key not in self.rope_cache: - seq_lens = frame * height * width - freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1) - freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1) - if self.scale_rope: - freqs_height = torch.cat( - [freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0 - ) - freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0) - freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1) - - else: - freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1) - freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1) - - freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1, 2, 2) - self.rope_cache[rope_key] = freqs.clone().contiguous() - vid_freqs.append(self.rope_cache[rope_key]) - - if self.scale_rope: - max_vid_index = max(height // 2, width // 2, max_vid_index) - else: - max_vid_index = max(height, width, max_vid_index) - - max_len = max(txt_seq_lens) - txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...] - vid_freqs = torch.cat(vid_freqs, dim=0) - - return vid_freqs, txt_freqs - - class GELU(nn.Module): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): super().__init__() @@ -477,9 +358,7 @@ class QwenImageTransformer2DModel(nn.Module): self.inner_dim = num_attention_heads * attention_head_dim self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) - # Add research-accurate RoPE for EliGen (returns tuple of img_freqs, txt_freqs) - self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=[16, 56, 56], scale_rope=True) - + self.time_text_embed = QwenTimestepProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim, @@ -529,134 +408,109 @@ class QwenImageTransformer2DModel(nn.Module): return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape def process_entity_masks(self, latents, prompt_emb, prompt_emb_mask, entity_prompt_emb, - entity_prompt_emb_mask, entity_masks, height, width, image): + entity_prompt_emb_mask, entity_masks, height, width, image, + cond_or_uncond=None, batch_size=None): """ - https://github.com/modelscope/DiffSynth-Studio/blob/538017177a6136f45f57cdf0b7c4e0d7e1f8b50d/diffsynth/models/qwen_image_dit.py#L434 Process entity masks and build spatial attention mask for EliGen. - This method: - 1. Concatenates entity + global prompts - 2. Builds RoPE embeddings for concatenated text using ComfyUI's pe_embedder - 3. Creates attention mask enforcing spatial restrictions + Concatenates entity+global prompts, builds RoPE embeddings, creates attention mask + enforcing spatial restrictions, and handles CFG batching with separate masks. - Args: - latents: [B, 16, H, W] - prompt_emb: [1, seq_len, 3584] - Global prompt - prompt_emb_mask: [1, seq_len] - entity_prompt_emb: List[[1, L_i, 3584]] - Entity prompts - entity_prompt_emb_mask: List[[1, L_i]] - entity_masks: [1, N, 1, H/8, W/8] - height: int (padded pixel height) - width: int (padded pixel width) - image: [B, patches, 64] - Patchified latents - - Returns: - all_prompt_emb: [1, total_seq, 3584] - image_rotary_emb: RoPE embeddings - attention_mask: [1, 1, total_seq, total_seq] + Based on: https://github.com/modelscope/DiffSynth-Studio """ num_entities = len(entity_prompt_emb) - batch_size = latents.shape[0] + actual_batch_size = latents.shape[0] + + has_positive = cond_or_uncond and 0 in cond_or_uncond + has_negative = cond_or_uncond and 1 in cond_or_uncond + is_cfg_batched = has_positive and has_negative + logger.debug( - f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px image " - f"(latents: {latents.shape}, batch_size: {batch_size})" + f"[EliGen Model] Processing {num_entities} entities for {height}x{width}px, " + f"batch_size={actual_batch_size}, CFG_batched={is_cfg_batched}" ) - # Validate batch consistency (all batches should have same sequence lengths) - # This is a ComfyUI requirement - batched prompts must have uniform padding - if batch_size > 1: - logger.debug(f"[EliGen Model] Batch size > 1 detected ({batch_size} batches), ensuring RoPE compatibility") - - # SECTION 1: Concatenate entity + global prompts + # Concatenate entity + global prompts all_prompt_emb = entity_prompt_emb + [prompt_emb] all_prompt_emb = [self.txt_in(self.txt_norm(local_prompt_emb)) for local_prompt_emb in all_prompt_emb] all_prompt_emb = torch.cat(all_prompt_emb, dim=1) - # SECTION 2: Build RoPE position embeddings - # For EliGen, we create RoPE for ONE batch element's dimensions - # The queries/keys have shape [batch, seq, heads, dim], and RoPE broadcasts across batch dim + # Build RoPE embeddings patch_h = height // self.PATCH_TO_PIXEL_RATIO patch_w = width // self.PATCH_TO_PIXEL_RATIO - # Create RoPE for a single image (frame=1 for images, not video) - # This will broadcast across all batch elements automatically - img_shapes_single = [(1, patch_h, patch_w)] - - # Calculate sequence lengths for entities and global prompt - # Use [0] to get first batch element (all batches should have same sequence lengths) entity_seq_lens = [int(mask.sum(dim=1)[0].item()) for mask in entity_prompt_emb_mask] - # Handle None case in ComfyUI (None means no padding, all tokens valid) if prompt_emb_mask is not None: global_seq_len = int(prompt_emb_mask.sum(dim=1)[0].item()) else: - # No mask = no padding, use full sequence length global_seq_len = int(prompt_emb.shape[1]) - # Get base image RoPE using global prompt length (returns tuple: (img_freqs, txt_freqs)) - # We pass a single shape, not repeated for batch, because RoPE will broadcast - txt_seq_lens = [global_seq_len] - image_rotary_emb = self.pos_embed(img_shapes_single, txt_seq_lens, device=latents.device) + max_vid_index = max(patch_h // 2, patch_w // 2) - # Create SEPARATE RoPE embeddings for each entity - # Each entity gets its own positional encoding based on its sequence length - # We only need to create these once since they're the same for all batch elements - entity_rotary_emb = [self.pos_embed([(1, patch_h, patch_w)], [entity_seq_len], device=latents.device)[1] - for entity_seq_len in entity_seq_lens] + # Generate per-entity text RoPE (each entity starts from same offset) + entity_txt_embs = [] + for entity_seq_len in entity_seq_lens: + entity_ids = torch.arange( + max_vid_index, + max_vid_index + entity_seq_len, + device=latents.device + ).reshape(1, -1, 1).repeat(1, 1, 3) - # Concatenate entity RoPEs with global RoPE along sequence dimension - # Result: [entity1_seq, entity2_seq, ..., global_seq] concatenated - # This creates the RoPE for ONE batch element's sequence - # Note: We DON'T repeat for batch_size because the queries/keys have shape [batch, seq, ...] - # and PyTorch will broadcast the RoPE [seq, ...] across the batch dimension automatically - txt_rotary_emb = torch.cat(entity_rotary_emb + [image_rotary_emb[1]], dim=0) + entity_rope = self.pe_embedder(entity_ids).squeeze(1).squeeze(0) + entity_txt_embs.append(entity_rope) - logger.debug( - f"[EliGen Model] RoPE created for single batch element - " - f"img: {image_rotary_emb[0].shape}, txt: {txt_rotary_emb.shape} " - f"(both will broadcast across batch_size={batch_size})" - ) + # Generate global text RoPE + global_ids = torch.arange( + max_vid_index, + max_vid_index + global_seq_len, + device=latents.device + ).reshape(1, -1, 1).repeat(1, 1, 3) + global_rope = self.pe_embedder(global_ids).squeeze(1).squeeze(0) - # Replace text part of tuple with concatenated entity + global RoPE - image_rotary_emb = (image_rotary_emb[0], txt_rotary_emb) + txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=0) - # SECTION 3: Prepare spatial masks - repeat_dim = latents.shape[1] # 16 (latent channels) - max_masks = entity_masks.shape[1] # N entities + h_coords = torch.arange(-(patch_h - patch_h // 2), patch_h // 2, device=latents.device) + w_coords = torch.arange(-(patch_w - patch_w // 2), patch_w // 2, device=latents.device) + + img_ids = torch.zeros((patch_h, patch_w, 3), device=latents.device) + img_ids[:, :, 0] = 0 + img_ids[:, :, 1] = h_coords.unsqueeze(1) + img_ids[:, :, 2] = w_coords.unsqueeze(0) + img_ids = img_ids.reshape(1, -1, 3) + + img_rope = self.pe_embedder(img_ids).squeeze(1).squeeze(0) + + logger.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}") + + image_rotary_emb = (img_rope, txt_rotary_emb) + + # Prepare spatial masks + repeat_dim = latents.shape[1] + max_masks = entity_masks.shape[1] entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) - # Pad masks to match padded latent dimensions - # entity_masks shape: [1, N, 16, H/8, W/8] - # Need to pad to match orig_shape which is [B, 16, padded_H/8, padded_W/8] padded_h = height // self.LATENT_TO_PIXEL_RATIO padded_w = width // self.LATENT_TO_PIXEL_RATIO if entity_masks.shape[3] != padded_h or entity_masks.shape[4] != padded_w: - assert entity_masks.shape[3] <= padded_h and entity_masks.shape[4] <= padded_w, \ - f"Entity masks {entity_masks.shape[3]}x{entity_masks.shape[4]} larger than padded dims {padded_h}x{padded_w}" - - # Pad each entity mask pad_h = padded_h - entity_masks.shape[3] pad_w = padded_w - entity_masks.shape[4] - logger.debug(f"[EliGen Model] Padding entity masks by ({pad_h}, {pad_w}) to match latent dimensions") + logger.debug(f"[EliGen Model] Padding masks by ({pad_h}, {pad_w})") entity_masks = torch.nn.functional.pad(entity_masks, (0, pad_w, 0, pad_h), mode='constant', value=0) entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] - # Add global mask (all True) - must be same size as padded entity masks global_mask = torch.ones((entity_masks[0].shape[0], entity_masks[0].shape[1], padded_h, padded_w), device=latents.device, dtype=latents.dtype) entity_masks = entity_masks + [global_mask] - # SECTION 4: Patchify masks + # Patchify masks N = len(entity_masks) batch_size = int(entity_masks[0].shape[0]) seq_lens = entity_seq_lens + [global_seq_len] total_seq_len = int(sum(seq_lens) + image.shape[1]) - logger.debug( - f"[EliGen Model] Building attention mask: " - f"total_seq={total_seq_len} (entities: {entity_seq_lens}, global: {global_seq_len}, image: {image.shape[1]})" - ) + logger.debug(f"[EliGen Model] total_seq={total_seq_len}") patched_masks = [] for i in range(N): @@ -694,7 +548,7 @@ class QwenImageTransformer2DModel(nn.Module): image_mask = torch.sum(patched_masks[i], dim=-1) > 0 image_mask = image_mask.unsqueeze(1).repeat(1, seq_lens[i], 1) - # Always repeat mask to match image sequence length (matches DiffSynth line 480) + # Always repeat mask to match image sequence length repeat_time = single_image_seq // image_mask.shape[-1] image_mask = image_mask.repeat(1, 1, repeat_time) @@ -713,12 +567,44 @@ class QwenImageTransformer2DModel(nn.Module): start_j, end_j = cumsum[j], cumsum[j+1] attention_mask[:, start_i:end_i, start_j:end_j] = False - # SECTION 6: Convert to additive bias + # SECTION 6: Convert to additive bias and handle CFG batching attention_mask = attention_mask.float() num_valid_connections = (attention_mask == 1).sum().item() attention_mask[attention_mask == 0] = float('-inf') attention_mask[attention_mask == 1] = 0 - attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype).unsqueeze(1) + attention_mask = attention_mask.to(device=latents.device, dtype=latents.dtype) + + # Handle CFG batching: Create separate masks for positive and negative + if is_cfg_batched and actual_batch_size > 1: + # CFG batch: [positive, negative] - need different masks for each + # Positive gets entity constraints, negative gets standard attention (all zeros) + + logger.debug( + f"[EliGen Model] CFG batched detected - creating separate masks. " + f"Positive (index 0) gets entity mask, Negative (index 1) gets standard mask" + ) + + # Create standard attention mask (all zeros = no constraints) + standard_mask = torch.zeros_like(attention_mask) + + # Stack masks according to cond_or_uncond order + mask_list = [] + for cond_type in cond_or_uncond: + if cond_type == 0: # Positive - use entity mask + mask_list.append(attention_mask[0:1]) # Take first (and only) entity mask + else: # Negative - use standard mask + mask_list.append(standard_mask[0:1]) + + # Concatenate masks to match batch + attention_mask = torch.cat(mask_list, dim=0) + + logger.debug( + f"[EliGen Model] Created {len(mask_list)} masks for CFG batch. " + f"Final shape: {attention_mask.shape}" + ) + + # Add head dimension: [B, 1, seq, seq] + attention_mask = attention_mask.unsqueeze(1) logger.debug( f"[EliGen Model] Attention mask created: shape={attention_mask.shape}, " @@ -778,23 +664,28 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - # Extract entity data from kwargs + # Extract EliGen entity data entity_prompt_emb = kwargs.get("entity_prompt_emb", None) entity_prompt_emb_mask = kwargs.get("entity_prompt_emb_mask", None) entity_masks = kwargs.get("entity_masks", None) - # Branch: EliGen vs Standard path - # Only apply EliGen to POSITIVE conditioning (cond_or_uncond contains 0) + # Detect batch composition for CFG handling cond_or_uncond = transformer_options.get("cond_or_uncond", []) if transformer_options else [] - is_positive_cond = 0 in cond_or_uncond # 0 = conditional/positive, 1 = unconditional/negative + is_positive_cond = 0 in cond_or_uncond + is_negative_cond = 1 in cond_or_uncond + batch_size = x.shape[0] + + if entity_prompt_emb is not None: + logger.debug( + f"[EliGen Forward] batch_size={batch_size}, cond_or_uncond={cond_or_uncond}, " + f"has_positive={is_positive_cond}, has_negative={is_negative_cond}" + ) if entity_prompt_emb is not None and entity_masks is not None and entity_prompt_emb_mask is not None and is_positive_cond: - # EliGen path - process entity masks (POSITIVE CONDITIONING ONLY) - # orig_shape is from process_img which pads to patch_size - height = int(orig_shape[-2] * 8) # Padded latent height -> pixel height (ensure int) - width = int(orig_shape[-1] * 8) # Padded latent width -> pixel width (ensure int) + # EliGen path + height = int(orig_shape[-2] * 8) + width = int(orig_shape[-1] * 8) - # Call process_entity_masks to get concatenated text, RoPE, and attention mask encoder_hidden_states, image_rotary_emb, eligen_attention_mask = self.process_entity_masks( latents=x, prompt_emb=encoder_hidden_states, @@ -804,22 +695,21 @@ class QwenImageTransformer2DModel(nn.Module): entity_masks=entity_masks, height=height, width=width, - image=hidden_states + image=hidden_states, + cond_or_uncond=cond_or_uncond, + batch_size=batch_size ) - # Apply image projection (text already processed in process_entity_masks) hidden_states = self.img_in(hidden_states) - # Store attention mask in transformer_options for the attention layers if transformer_options is None: transformer_options = {} transformer_options["eligen_attention_mask"] = eligen_attention_mask - # Clean up del img_ids else: - # Standard path - existing code + # Standard path txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1)