mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-11 16:57:29 +08:00
Minor adjustments/safeguards
This commit is contained in:
parent
f927863377
commit
e5bb1fe22a
@ -1628,9 +1628,8 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
||||
|
||||
reference_latent = kwargs.get("reference_latent", None)
|
||||
if reference_latent is not None:
|
||||
x = torch.cat((reference_latent, x), dim=2)
|
||||
|
||||
@ -1726,8 +1725,10 @@ class SCAILWanModel(WanModel):
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = x.shape[2]
|
||||
|
||||
reference_latent = None
|
||||
if "reference_latent" in kwargs:
|
||||
t_len += 1
|
||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||
t_len += reference_latent.shape[2]
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, **kwargs)[:, :, :t, :h, :w]
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
@ -1498,7 +1498,7 @@ class WAN21_SCAIL(WAN21):
|
||||
return out
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, pose_start=0.0, pose_end=1.0, **kwargs):
|
||||
if t >= self.model_sampling.percent_to_sigma(pose_start) or t <= self.model_sampling.percent_to_sigma(pose_end):
|
||||
if t[0] >= self.model_sampling.percent_to_sigma(pose_start) or t[0] <= self.model_sampling.percent_to_sigma(pose_end):
|
||||
kwargs.pop("pose_latents", None)
|
||||
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
@ -1511,11 +1511,11 @@ class WAN21_SCAIL(WAN21):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
|
||||
pose_latents = kwargs.get("pose_video_latent", None)
|
||||
if pose_latents is not None:
|
||||
out['pose_latents'] = pose_latents.shape
|
||||
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user