From 7a00454fbdb4860bad6d949bac0d166e276e1c3a Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Tue, 24 Feb 2026 16:00:33 +0200 Subject: [PATCH] Simplify RoPE scaling --- comfy/ldm/wan/model.py | 38 +++++++++----------------------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index f1ae2e896..730ecc28f 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1700,38 +1700,18 @@ class SCAILWanModel(WanModel): return main_freqs F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] - downscale = H_pose != h - # when using pose downscaling, encode at the actual resolution so the freq space covers the right range, then pool back down below - pose_H_virtual = H_pose * 2 if downscale else H_pose - pose_W_virtual = W_pose * 2 if downscale else W_pose - pose_transformer_options = {"rope_options": {"shift_x": 120.0}} # pose frames use a fixed w-offset of 120 to spatially separate them from the main frames - pose_freqs = super().rope_encode(F_pose, pose_H_virtual, pose_W_virtual, t_start=t_start, device=device, dtype=dtype, transformer_options=pose_transformer_options) + # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames + h_scale = h / H_pose + w_scale = w / W_pose - freqs = torch.cat([main_freqs, pose_freqs], dim=1) + # 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code + h_shift = (h_scale - 1) / 2 + w_shift = (w_scale - 1) / 2 + pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}} + pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start, device=device, dtype=dtype, transformer_options=pose_transformer_options) - # downsample pose frequencies to match actual pose input resolution - if downscale: - pose_f_len_full = ((F_pose + (self.patch_size[0] // 2)) // self.patch_size[0]) - pose_h_len_full = (((H_pose * 2) + (self.patch_size[1] // 2)) // self.patch_size[1]) - pose_w_len_full = (((W_pose * 2) + (self.patch_size[2] // 2)) // self.patch_size[2]) - pose_h_len_actual = ((H_pose + (self.patch_size[1] // 2)) // self.patch_size[1]) - pose_w_len_actual = ((W_pose + (self.patch_size[2] // 2)) // self.patch_size[2]) - - pose_start_idx = freqs.shape[1] - pose_f_len_full * pose_h_len_full * pose_w_len_full - main_freqs, pose_freqs = freqs[:, :pose_start_idx], freqs[:, pose_start_idx:] - - B, _, heads, dim, _, _ = pose_freqs.shape - # Reshape and pool: (B, L, heads, dim, 2, 2) -> pool H,W -> (B, L', heads, dim, 2, 2) - pose_freqs = pose_freqs.reshape(B, pose_f_len_full, pose_h_len_full, pose_w_len_full, heads, dim, 2, 2) - pose_freqs = pose_freqs.permute(0, 1, 4, 5, 6, 7, 2, 3).reshape(-1, pose_h_len_full, pose_w_len_full) - pose_freqs = torch.nn.functional.avg_pool2d(pose_freqs, kernel_size=2, stride=2) - pose_freqs = pose_freqs.reshape(B, pose_f_len_full, heads, dim, 2, 2, pose_h_len_actual, pose_w_len_actual) - pose_freqs = pose_freqs.permute(0, 1, 6, 7, 2, 3, 4, 5).reshape(B, -1, heads, dim, 2, 2) - - freqs = torch.cat([main_freqs, pose_freqs], dim=1) - - return freqs + return torch.cat([main_freqs, pose_freqs], dim=1) def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): bs, c, t, h, w = x.shape