diff --git a/comfy/model_base.py b/comfy/model_base.py index 88155b9ae..f5224a840 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1524,6 +1524,20 @@ class WAN21(BaseModel): return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + # In-context streams slicing (Bernini) + if cond_key == "context_latents" and isinstance(getattr(cond_value, "cond", None), list): + dim = window.dim + out = [] + for lat in cond_value.cond: + if lat.ndim > dim and lat.shape[dim] > 1 and lat.shape[dim] == x_in.shape[dim]: + idx = tuple([slice(None)] * dim + [window.index_list]) + out.append(lat[idx].to(device)) + else: + out.append(lat.to(device)) + return cond_value._copy_with(out) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21_CausalAR(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):