Merge interpolate_pos_encoding method in dinov2

This commit is contained in:
Talmaj Marinc 2026-05-19 12:03:51 +02:00
parent 2ba487e970
commit e26ba849e6

View File

@ -1,5 +1,3 @@
import math
import torch
import torch.nn.functional as F
@ -257,37 +255,6 @@ class Dino2Embeddings(torch.nn.Module):
else:
self.camera_token = None
def _interpolate_pos_encoding(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Resample the learned position embeddings to the patch grid of an ``h`` x ``w`` input.

    Args:
        x: Token sequence; ``x.shape[1] - 1`` is taken as the number of patch
            tokens and ``x.shape[-1]`` as the embedding width. The result is
            placed on ``x.device`` and returned in ``x.dtype``.
        h: Input height in pixels.
        w: Input width in pixels.

    Returns:
        Position embeddings with the class slot first, followed by the patch
        positions in row-major (height, width) order, always in ``x.dtype``.

    Raises:
        AssertionError: If the stored patch grid is not square, or if the
            interpolated grid does not match ``(h // patch_size, w // patch_size)``.
    """
    out_dtype = x.dtype
    num_tokens = x.shape[1] - 1
    num_positions = self.position_embeddings.shape[1] - 1

    # Cast straight to float32 so the bicubic resample runs at full precision;
    # going through x.dtype first would round the weights to half precision
    # whenever x is fp16/bf16.
    pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)

    # Same square grid as the stored one: nothing to resample. Still convert to
    # the caller's dtype so both return paths produce the same dtype.
    if num_tokens == num_positions and w == h:
        return pos_embed.to(out_dtype)

    cls_embed = pos_embed[:, :1]    # (1, 1, dim)
    grid_embed = pos_embed[:, 1:]   # (1, num_positions, dim)
    dim = x.shape[-1]
    ph = h // self.patch_size  # patch grid height
    pw = w // self.patch_size  # patch grid width

    M = int(math.sqrt(num_positions))
    assert num_positions == M * M, (
        f"DINOv2 position grid must be square, got N={num_positions} patches (sqrt={M})"
    )

    # The +0.1 offset is the historical upstream DINOv2 workaround that keeps
    # the interpolated output size landing exactly on (ph, pw).
    # F.interpolate reads scale_factor as (height_scale, width_scale), so the
    # height scale MUST come first; swapping them only works for square inputs.
    scale_factor = ((ph + 0.1) / M, (pw + 0.1) / M)
    grid_embed = F.interpolate(
        grid_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
        scale_factor=scale_factor, mode="bicubic", antialias=False,
    )
    assert (ph, pw) == grid_embed.shape[-2:], (
        f"Interpolated pos-embed grid {tuple(grid_embed.shape[-2:])} does not match "
        f"target patch grid ({ph}, {pw}) for input {h}x{w} (patch_size={self.patch_size})"
    )

    # Back to (1, ph * pw, dim), then put the class slot in front again.
    grid_embed = grid_embed.permute(0, 2, 3, 1).flatten(1, 2)
    return torch.cat((cls_embed, grid_embed), dim=1).to(out_dtype)
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
@ -295,12 +262,22 @@ class Dino2Embeddings(torch.nn.Module):
patch_pos = pos_embed[:, 1:]
N = patch_pos.shape[1]
M = int(N ** 0.5)
assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})"
h0 = h_pixels // self.patch_size
w0 = w_pixels // self.patch_size
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
# +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
# scale_factor is (height_scale, width_scale) -- height MUST come first;
# swapping these only happens to work for square inputs and breaks
# non-square paths like DA3-Small / DA3-Base multi-view.
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M)
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
assert (h0, w0) == patch_pos.shape[-2:], (
f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match "
f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); "
f"check scale_factor axis order and +0.1 rounding workaround"
)
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)