Maybe fix context windows for v2v

This commit is contained in:
kijai 2026-06-02 01:44:44 +03:00
parent f87432bafb
commit 2c7d2561af

View File

@ -1524,6 +1524,20 @@ class WAN21(BaseModel):
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
# In-context streams slicing (Bernini)
if cond_key == "context_latents" and isinstance(getattr(cond_value, "cond", None), list):
dim = window.dim
out = []
for lat in cond_value.cond:
if lat.ndim > dim and lat.shape[dim] > 1 and lat.shape[dim] == x_in.shape[dim]:
idx = tuple([slice(None)] * dim + [window.index_list])
out.append(lat[idx].to(device))
else:
out.append(lat.to(device))
return cond_value._copy_with(out)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN21_CausalAR(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):