mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 01:39:25 +08:00
Improve context window resizing for SCAIL2
This commit is contained in:
parent
039ed38ed1
commit
1e0a250832
@ -1816,7 +1816,24 @@ class WAN21_SCAIL2(WAN21_SCAIL):
|
|||||||
|
|
||||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
if cond_key in ("sam_latents", "pose_latents"):
|
if cond_key in ("sam_latents", "pose_latents"):
|
||||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
|
# Return sliced view omitting retain_index_list
|
||||||
|
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0)
|
||||||
|
if cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
|
# The ref mask is just a single frame padded with zero frames, so just grab the first frames for all windows
|
||||||
|
full_ref_mask = cond_value.cond
|
||||||
|
video_frame_count = x_in.shape[2]
|
||||||
|
if full_ref_mask.shape[2] != video_frame_count + 1:
|
||||||
|
return None
|
||||||
|
window_length = len(window.index_list)
|
||||||
|
|
||||||
|
# account for the causal anchor frame at the end of the ref mask if it exists
|
||||||
|
anchor_index = getattr(window, "causal_anchor_index", None)
|
||||||
|
if anchor_index is not None and anchor_index >= 0:
|
||||||
|
window_length += 1
|
||||||
|
|
||||||
|
window_ref_mask = full_ref_mask[:, :, :window_length + 1].to(device)
|
||||||
|
return cond_value._copy_with(window_ref_mask)
|
||||||
|
|
||||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||||
|
|
||||||
def concat_cond(self, **kwargs):
|
def concat_cond(self, **kwargs):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user