diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 2268bff38..30a36ad49 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -15,15 +15,6 @@ import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention import comfy.ldm.common_dit -def apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, -) -> torch.Tensor: - t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() - t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] - t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) - return t_out - # ---------------------- Feed Forward Network ----------------------- class GPT2FeedForward(nn.Module): @@ -173,8 +164,7 @@ class Attention(nn.Module): k = self.k_norm(k) v = self.v_norm(v) if self.is_selfattn and rope_emb is not None: # only apply to self-attention! - q = apply_rotary_pos_emb(q, rope_emb) - k = apply_rotary_pos_emb(k, rope_emb) + q, k = comfy.quant_ops.ck.apply_rope_split_half(q, k, rope_emb) return q, k, v q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) diff --git a/requirements.txt b/requirements.txt index 0617667e1..7ae16eb5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,7 +22,7 @@ alembic SQLAlchemy>=2.0.0 filelock av>=16.0.0 -comfy-kitchen==0.2.9 +comfy-kitchen==0.2.10 comfy-aimdo==0.4.5 requests simpleeval>=1.0.0