From 81aa5a38b25a59dadb8b2765e08e277e044351a6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 30 May 2026 17:53:37 -0700 Subject: [PATCH] Speed up ernie model by a bit on nvidia and use higher quality rope. (#14192) --- comfy/ldm/cosmos/predict2.py | 1 + comfy/ldm/ernie/model.py | 27 +++++++++++++-------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 30a36ad49..671fe834d 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -14,6 +14,7 @@ from torchvision import transforms import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention import comfy.ldm.common_dit +import comfy.quant_ops # ---------------------- Feed Forward Network ----------------------- diff --git a/comfy/ldm/ernie/model.py b/comfy/ldm/ernie/model.py index eba661aec..f158ca1d2 100644 --- a/comfy/ldm/ernie/model.py +++ b/comfy/ldm/ernie/model.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import comfy.quant_ops def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0 @@ -19,15 +20,6 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: out = torch.stack([torch.cos(out), torch.sin(out)], dim=0) return out.to(dtype=torch.float32, device=pos.device) -def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - rot_dim = freqs_cis.shape[-1] - x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:] - cos_ = freqs_cis[0] - sin_ = freqs_cis[1] - x1, x2 = x.chunk(2, dim=-1) - x_rotated = torch.cat((-x2, x1), dim=-1) - return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1) - class ErnieImageEmbedND3(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: tuple): super().__init__() @@ -37,8 +29,16 @@ class ErnieImageEmbedND3(nn.Module): def forward(self, ids: torch.Tensor) -> torch.Tensor: emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1) - emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2] - return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim] + cos_ = emb[0] + sin_ = emb[1] + N = cos_.shape[-1] + half = N // 2 + cos_top = cos_[..., :half].repeat_interleave(2, dim=-1) + sin_top = sin_[..., :half].repeat_interleave(2, dim=-1) + cos_bot = cos_[..., half:].repeat_interleave(2, dim=-1) + sin_bot = sin_[..., half:].repeat_interleave(2, dim=-1) + rot = torch.stack([cos_top, -sin_top, sin_bot, cos_bot], dim=-1) + return rot.reshape(*rot.shape[:-1], 2, 2).unsqueeze(2) class ErnieImagePatchEmbedDynamic(nn.Module): def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None): @@ -115,8 +115,7 @@ class ErnieImageAttention(nn.Module): key = self.norm_k(key) if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) + query, key = comfy.quant_ops.ck.apply_rope_split_half(query, key, image_rotary_emb) q_flat = query.reshape(B, S, -1) k_flat = key.reshape(B, S, -1) @@ -274,7 +273,7 @@ class ErnieImageModel(nn.Module): image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1) - rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype) + rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) del image_ids, text_ids sample = self.time_proj(timesteps).to(dtype)