Minor adjustments/safeguards

This commit is contained in:
kijai 2026-02-24 21:02:41 +02:00
parent f927863377
commit e5bb1fe22a
2 changed files with 8 additions and 7 deletions

View File

@@ -1628,9 +1628,8 @@ class SCAILWanModel(WanModel):
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, **kwargs):
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
reference_latent = kwargs.get("reference_latent", None)
if reference_latent is not None:
x = torch.cat((reference_latent, x), dim=2)
@@ -1726,8 +1725,10 @@ class SCAILWanModel(WanModel):
x = torch.cat([x, time_dim_concat], dim=2)
t_len = x.shape[2]
reference_latent = None
if "reference_latent" in kwargs:
t_len += 1
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
t_len += reference_latent.shape[2]
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, **kwargs)[:, :, :t, :h, :w]
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]

View File

@@ -1498,7 +1498,7 @@ class WAN21_SCAIL(WAN21):
return out
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, pose_start=0.0, pose_end=1.0, **kwargs):
if t >= self.model_sampling.percent_to_sigma(pose_start) or t <= self.model_sampling.percent_to_sigma(pose_end):
if t[0] >= self.model_sampling.percent_to_sigma(pose_start) or t[0] <= self.model_sampling.percent_to_sigma(pose_end):
kwargs.pop("pose_latents", None)
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -1511,11 +1511,11 @@ class WAN21_SCAIL(WAN21):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None:
out['pose_latents'] = pose_latents.shape
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
return out