Cleanup NAF

This commit is contained in:
kijai 2026-07-02 11:24:16 +03:00
parent 29e2118717
commit be67e2366d
2 changed files with 2 additions and 11 deletions

View File

@ -157,7 +157,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
for k in keys:
if k not in u:
sd.pop(k)
# NAF feature upsampler ships bundled into the DINOv3 file under the `naf.` prefix.
# NAF feature upsampler bundled into the DINOv3 file under the `naf.` prefix.
naf_keys = [k for k in sd if k.startswith("naf.")]
if naf_keys:
naf_sd = {k[len("naf."):]: sd.pop(k) for k in naf_keys}

View File

@ -174,16 +174,9 @@ class RoPE(nn.Module):
self.D_head = embed_dim // num_heads
self.base = base
self.register_buffer("periods", torch.empty(self.D_head // 4), persistent=True) # loaded from the checkpoint
self._cached_key = None
self._cached_cos_sin = None
def _cos_sin(self, H: int, W: int, dtype: torch.dtype):
"""cos/sin only depend on (H, W) and the output dtype (periods are fixed
once loaded from the checkpoint), so cache them saves the meshgrid /
angle / cos / sin / tile / flatten on every forward."""
key = (H, W, dtype)
if self._cached_key == key and self._cached_cos_sin is not None:
return self._cached_cos_sin
"""cos/sin depend only on (H, W, dtype) and the checkpoint-fixed periods; recomputed per forward."""
device = self.periods.device
coords_h = torch.arange(0.5, H, device=device, dtype=torch.float32) / H
coords_w = torch.arange(0.5, W, device=device, dtype=torch.float32) / W
@ -193,8 +186,6 @@ class RoPE(nn.Module):
angles = angles.flatten(1, 2).tile(2) # [HW, D]
cos = torch.cos(angles).to(dtype)
sin = torch.sin(angles).to(dtype)
self._cached_cos_sin = (cos, sin)
self._cached_key = key
return cos, sin
def forward(self, x: torch.Tensor) -> torch.Tensor: