diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 05e55cc14..1db1536a4 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -1628,9 +1628,8 @@ class SCAILWanModel(WanModel): self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) - def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, **kwargs): + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs): - reference_latent = kwargs.get("reference_latent", None) if reference_latent is not None: x = torch.cat((reference_latent, x), dim=2) @@ -1726,8 +1725,10 @@ class SCAILWanModel(WanModel): x = torch.cat([x, time_dim_concat], dim=2) t_len = x.shape[2] + reference_latent = None if "reference_latent" in kwargs: - t_len += 1 + reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) + t_len += reference_latent.shape[2] freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, **kwargs)[:, :, :t, :h, :w] + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] diff --git a/comfy/model_base.py b/comfy/model_base.py index 4821dbb12..025c7fb7d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1498,7 +1498,7 @@ class WAN21_SCAIL(WAN21): return out def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, pose_start=0.0, pose_end=1.0, **kwargs): - if t >= self.model_sampling.percent_to_sigma(pose_start) or t <= self.model_sampling.percent_to_sigma(pose_end): + if t[0] >= self.model_sampling.percent_to_sigma(pose_start) or t[0] <= self.model_sampling.percent_to_sigma(pose_end): kwargs.pop("pose_latents", None) return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -1511,11 +1511,11 @@ class WAN21_SCAIL(WAN21): out = {} ref_latents = kwargs.get("reference_latents", None) if ref_latents is not None: - out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) pose_latents = kwargs.get("pose_video_latent", None) if pose_latents is not None: - out['pose_latents'] = pose_latents.shape + out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] return out