diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 1db1536a4..b2287dba9 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1692,12 +1692,16 @@ class SCAILWanModel(WanModel): return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, transformer_options={}): + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}): main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) if pose_latents is None: return main_freqs + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames @@ -1708,7 +1712,7 @@ class SCAILWanModel(WanModel): h_shift = (h_scale - 1) / 2 w_shift = (w_scale - 1) / 2 pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}} - pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start, device=device, dtype=dtype, transformer_options=pose_transformer_options) + pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options) return torch.cat([main_freqs, pose_freqs], dim=1) @@ -1730,5 +1734,5 @@ class SCAILWanModel(WanModel): reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) t_len += reference_latent.shape[2] - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents) + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]