diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index 97e86da19..59e8c0b6f 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -1,5 +1,3 @@ -import math - import torch import torch.nn.functional as F @@ -257,37 +255,6 @@ class Dino2Embeddings(torch.nn.Module): else: self.camera_token = None - def _interpolate_pos_encoding(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor: - previous_dtype = x.dtype - npatch = x.shape[1] - 1 - N = self.position_embeddings.shape[1] - 1 - pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype).float() - if npatch == N and w == h: - return pos_embed - class_pos_embed = pos_embed[:, 0] - patch_pos_embed = pos_embed[:, 1:] - dim = x.shape[-1] - ph = h // self.patch_size # patch grid height - pw = w // self.patch_size # patch grid width - M = int(math.sqrt(N)) - assert N == M * M - # Historical 0.1 offset preserves bicubic resample compatibility with - # the original DINOv2 release; see the upstream PR for context. - # ``scale_factor`` is interpreted as (height_scale, width_scale) by - # ``F.interpolate`` so we must put the height scale FIRST. Earlier - # revisions of this function had it swapped which only worked for - # square inputs (e.g. CLIP-vision square crops); non-square inputs - # like DA3-Small / DA3-Base multi-view paths exposed the bug. - sh = float(ph + 0.1) / M - sw = float(pw + 0.1) / M - patch_pos_embed = F.interpolate( - patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), - scale_factor=(sh, sw), mode="bicubic", antialias=False, - ) - assert (ph, pw) == patch_pos_embed.shape[-2:] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) - def interpolate_pos_encoding(self, x, h_pixels, w_pixels): pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32) @@ -295,12 +262,22 @@ class Dino2Embeddings(torch.nn.Module): patch_pos = pos_embed[:, 1:] N = patch_pos.shape[1] M = int(N ** 0.5) + assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})" h0 = h_pixels // self.patch_size w0 = w_pixels // self.patch_size - scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0). + # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0). + # scale_factor is (height_scale, width_scale) -- height MUST come first; + # swapping these only happens to work for square inputs and breaks + # non-square paths like DA3-Small / DA3-Base multi-view. + scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2) patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False) + assert (h0, w0) == patch_pos.shape[-2:], ( + f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match " + f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); " + f"check scale_factor axis order and +0.1 rounding workaround" + ) patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2) return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)