mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-12 01:07:30 +08:00
Simplify RoPE scaling
This commit is contained in:
parent
d9927cdebd
commit
7a00454fbd
@ -1700,38 +1700,18 @@ class SCAILWanModel(WanModel):
|
||||
return main_freqs
|
||||
|
||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||
downscale = H_pose != h
|
||||
# when using pose downscaling, encode at the doubled (2x) virtual resolution so the freq space covers the right range, then average-pool back down to the pose resolution below
|
||||
pose_H_virtual = H_pose * 2 if downscale else H_pose
|
||||
pose_W_virtual = W_pose * 2 if downscale else W_pose
|
||||
|
||||
pose_transformer_options = {"rope_options": {"shift_x": 120.0}} # pose frames use a fixed w-offset of 120 to spatially separate them from the main frames
|
||||
pose_freqs = super().rope_encode(F_pose, pose_H_virtual, pose_W_virtual, t_start=t_start, device=device, dtype=dtype, transformer_options=pose_transformer_options)
|
||||
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
|
||||
h_scale = h / H_pose
|
||||
w_scale = w / W_pose
|
||||
|
||||
freqs = torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
# keep the fixed w-offset of 120 and add a shift of (scale - 1) / 2 (0.5 at half resolution) to place positions at midpoints (0.5, 2.5, ...), matching the original code
|
||||
h_shift = (h_scale - 1) / 2
|
||||
w_shift = (w_scale - 1) / 2
|
||||
pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||
pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start, device=device, dtype=dtype, transformer_options=pose_transformer_options)
|
||||
|
||||
# downsample pose frequencies to match actual pose input resolution
|
||||
if downscale:
|
||||
pose_f_len_full = ((F_pose + (self.patch_size[0] // 2)) // self.patch_size[0])
|
||||
pose_h_len_full = (((H_pose * 2) + (self.patch_size[1] // 2)) // self.patch_size[1])
|
||||
pose_w_len_full = (((W_pose * 2) + (self.patch_size[2] // 2)) // self.patch_size[2])
|
||||
pose_h_len_actual = ((H_pose + (self.patch_size[1] // 2)) // self.patch_size[1])
|
||||
pose_w_len_actual = ((W_pose + (self.patch_size[2] // 2)) // self.patch_size[2])
|
||||
|
||||
pose_start_idx = freqs.shape[1] - pose_f_len_full * pose_h_len_full * pose_w_len_full
|
||||
main_freqs, pose_freqs = freqs[:, :pose_start_idx], freqs[:, pose_start_idx:]
|
||||
|
||||
B, _, heads, dim, _, _ = pose_freqs.shape
|
||||
# Reshape and pool: (B, L, heads, dim, 2, 2) -> pool H,W -> (B, L', heads, dim, 2, 2)
|
||||
pose_freqs = pose_freqs.reshape(B, pose_f_len_full, pose_h_len_full, pose_w_len_full, heads, dim, 2, 2)
|
||||
pose_freqs = pose_freqs.permute(0, 1, 4, 5, 6, 7, 2, 3).reshape(-1, pose_h_len_full, pose_w_len_full)
|
||||
pose_freqs = torch.nn.functional.avg_pool2d(pose_freqs, kernel_size=2, stride=2)
|
||||
pose_freqs = pose_freqs.reshape(B, pose_f_len_full, heads, dim, 2, 2, pose_h_len_actual, pose_w_len_actual)
|
||||
pose_freqs = pose_freqs.permute(0, 1, 6, 7, 2, 3, 4, 5).reshape(B, -1, heads, dim, 2, 2)
|
||||
|
||||
freqs = torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
|
||||
return freqs
|
||||
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
|
||||
Loading…
Reference in New Issue
Block a user