diff --git a/comfy/context_windows.py b/comfy/context_windows.py index b54f7f39a..cb44ee6e8 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -93,6 +93,50 @@ class IndexListCallbacks: return {} +def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]): + if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)): + return None + cond_tensor = cond_value.cond + if temporal_dim >= cond_tensor.ndim: + return None + + cond_size = cond_tensor.size(temporal_dim) + + if temporal_scale == 1: + expected_size = x_in.size(window.dim) - temporal_offset + if cond_size != expected_size: + return None + + if temporal_offset == 0 and temporal_scale == 1: + sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list) + return cond_value._copy_with(sliced) + + # skip leading latent positions that have no corresponding conditioning (e.g. reference frames) + if temporal_offset > 0: + indices = [i - temporal_offset for i in window.index_list[temporal_offset:]] + indices = [i for i in indices if 0 <= i] + else: + indices = list(window.index_list) + + if not indices: + return None + + if temporal_scale > 1: + scaled = [] + for i in indices: + for k in range(temporal_scale): + si = i * temporal_scale + k + if si < cond_size: + scaled.append(si) + indices = scaled + if not indices: + return None + + idx = tuple([slice(None)] * temporal_dim + [indices]) + sliced = cond_tensor[idx].to(device) + return cond_value._copy_with(sliced) + + @dataclass class ContextSchedule: name: str @@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC): new_cond_item[cond_key] = result handled = True break + if not handled and self._model is not None: + result = self._model.resize_cond_for_context_window( + cond_key, cond_value, window, x_in, device, + retain_index_list=self.cond_retain_index_list) + if result is not None: + new_cond_item[cond_key] = result + handled = True if handled: continue if isinstance(cond_value, torch.Tensor): - if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ + if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \ (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): new_cond_item[cond_key] = window.get_tensor(cond_value, device) # Handle audio_embed (temporal dim is 1) @@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC): return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + self._model = model self.set_step(timestep, model_options) context_windows = self.get_context_windows(model, x_in, model_options) enumerated_context_windows = list(enumerate(context_windows)) diff --git a/comfy/model_base.py b/comfy/model_base.py index d9d5a9293..88905e191 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -285,6 +285,12 @@ class BaseModel(torch.nn.Module): return data return None + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + """Override in subclasses to handle model-specific cond slicing for context windows. + Return a sliced cond object, or None to fall through to default handling. + Use comfy.context_windows.slice_cond() for common cases.""" + return None + def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -1375,6 +1381,12 @@ class WAN21_Vace(WAN21): out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "vace_context": + import comfy.context_windows + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21_Camera(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) @@ -1427,6 +1439,12 @@ class WAN21_HuMo(WAN21): return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + import comfy.context_windows + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_Animate(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) @@ -1444,6 +1462,14 @@ class WAN22_Animate(WAN21): out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + import comfy.context_windows + if cond_key == "face_pixel_values": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1) + if cond_key == "pose_latents": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) @@ -1480,6 +1506,12 @@ class WAN22_S2V(WAN21): out['reference_motion'] = reference_motion.shape return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + import comfy.context_windows + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 93a5204e1..0e43f2e44 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode): io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), - #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."),