From 987a937658c163961b2010c0293d0bac05d1cc4d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com>
Date: Wed, 27 May 2026 22:08:06 +0300
Subject: [PATCH] Support context window for PiD and fix lq_latent rounding (#14136)

---
 comfy/ldm/pixeldit/pid.py |  5 +++--
 comfy/model_base.py       | 29 +++++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 2 deletions(-)

diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py
index 0ad4b7ce8..21b73907a 100644
--- a/comfy/ldm/pixeldit/pid.py
+++ b/comfy/ldm/pixeldit/pid.py
@@ -207,8 +207,9 @@ class PidNet(PixDiT_T2I):
                 f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
             )
         B = x.shape[0]
-        Hs = x.shape[2] // self.patch_size
-        Ws = x.shape[3] // self.patch_size
+        # Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream.
+        Hs = -(-x.shape[2] // self.patch_size)
+        Ws = -(-x.shape[3] // self.patch_size)
 
         degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
         if degrade_sigma.numel() == 1 and B > 1:
diff --git a/comfy/model_base.py b/comfy/model_base.py
index e55808633..205178911 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1428,6 +1428,35 @@ class PiD(PixelDiTT2I):
             out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
         return out
 
+    def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+        """Slice the ``lq_latent`` cond down to the part matching a context window.
+
+        Each index ``i`` in ``window.index_list`` is mapped to lq index ``i // ratio``,
+        where ``ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor``.
+        Any other cond is delegated to the parent implementation.
+
+        Returns a copy of ``cond_value`` wrapping the selected slice moved to
+        ``device``, or ``None`` if ``window.dim`` is not a dim of the lq tensor or
+        no mapped index lands inside it.
+        """
+        if cond_key != "lq_latent" or not isinstance(getattr(cond_value, "cond", None), torch.Tensor):
+            return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
+        lq = cond_value.cond
+        dim = window.dim
+        if dim >= lq.ndim:
+            return None
+        lq_proj = self.diffusion_model.lq_proj
+        ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor
+        lq_size = lq.size(dim)
+        # Map x window indices -> lq indices; the set dedupes x indices that share an lq index.
+        mapped = {i // ratio for i in window.index_list}
+        lq_indices = sorted(j for j in mapped if 0 <= j < lq_size)
+        if not lq_indices:
+            return None
+        # Full slices for the leading dims, then the selected positions along `dim`.
+        idx = tuple([slice(None)] * dim + [lq_indices])
+        return cond_value._copy_with(lq[idx].to(device))
+
 class WAN21(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):