mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-21 06:27:24 +08:00
Merge interpolate_pos_encoding method in dinov2
This commit is contained in:
parent
2ba487e970
commit
e26ba849e6
@ -1,5 +1,3 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -257,37 +255,6 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
else:
|
||||
self.camera_token = None
|
||||
|
||||
def _interpolate_pos_encoding(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
|
||||
previous_dtype = x.dtype
|
||||
npatch = x.shape[1] - 1
|
||||
N = self.position_embeddings.shape[1] - 1
|
||||
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype).float()
|
||||
if npatch == N and w == h:
|
||||
return pos_embed
|
||||
class_pos_embed = pos_embed[:, 0]
|
||||
patch_pos_embed = pos_embed[:, 1:]
|
||||
dim = x.shape[-1]
|
||||
ph = h // self.patch_size # patch grid height
|
||||
pw = w // self.patch_size # patch grid width
|
||||
M = int(math.sqrt(N))
|
||||
assert N == M * M
|
||||
# Historical 0.1 offset preserves bicubic resample compatibility with
|
||||
# the original DINOv2 release; see the upstream PR for context.
|
||||
# ``scale_factor`` is interpreted as (height_scale, width_scale) by
|
||||
# ``F.interpolate`` so we must put the height scale FIRST. Earlier
|
||||
# revisions of this function had it swapped which only worked for
|
||||
# square inputs (e.g. CLIP-vision square crops); non-square inputs
|
||||
# like DA3-Small / DA3-Base multi-view paths exposed the bug.
|
||||
sh = float(ph + 0.1) / M
|
||||
sw = float(pw + 0.1) / M
|
||||
patch_pos_embed = F.interpolate(
|
||||
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
||||
scale_factor=(sh, sw), mode="bicubic", antialias=False,
|
||||
)
|
||||
assert (ph, pw) == patch_pos_embed.shape[-2:]
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
||||
|
||||
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
|
||||
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
|
||||
|
||||
@ -295,12 +262,22 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
patch_pos = pos_embed[:, 1:]
|
||||
N = patch_pos.shape[1]
|
||||
M = int(N ** 0.5)
|
||||
assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})"
|
||||
h0 = h_pixels // self.patch_size
|
||||
w0 = w_pixels // self.patch_size
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
# +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
# scale_factor is (height_scale, width_scale) -- height MUST come first;
|
||||
# swapping these only happens to work for square inputs and breaks
|
||||
# non-square paths like DA3-Small / DA3-Base multi-view.
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M)
|
||||
|
||||
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
|
||||
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
|
||||
assert (h0, w0) == patch_pos.shape[-2:], (
|
||||
f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match "
|
||||
f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); "
|
||||
f"check scale_factor axis order and +0.1 rounding workaround"
|
||||
)
|
||||
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user