Cleanup of comments.

This commit is contained in:
Talmaj Marinc 2026-05-19 14:55:42 +02:00
parent 2686038f94
commit 5e889e73b9
6 changed files with 7 additions and 64 deletions

View File

@ -106,11 +106,6 @@ class _Block(nn.Module):
return x
# -----------------------------------------------------------------------------
# Camera encoder
# -----------------------------------------------------------------------------
class CameraEnc(nn.Module):
"""Encode per-view (extrinsics, intrinsics) into a camera token.
@ -165,11 +160,6 @@ class CameraEnc(nn.Module):
return tokens
# -----------------------------------------------------------------------------
# Camera decoder
# -----------------------------------------------------------------------------
class CameraDec(nn.Module):
"""Decode the final cam token into a 9-D pose encoding.

View File

@ -65,8 +65,7 @@ def _build_backbone_config(
layer_norm_eps=1e-6,
patch_size=14,
image_size=518,
# DA3 weights have no mask_token; skip registering it to avoid spurious
# missing-key warnings on load.
# DA3 weights have no mask_token; don't register it, to avoid missing-key warnings on load.
use_mask_token=False,
alt_start=alt_start,
qknorm_start=qknorm_start,
@ -149,10 +148,7 @@ class DepthAnything3Net(nn.Module):
)
self.head = head_cls(**head_kwargs)
# Camera encoder / decoder are only constructed when their weights are
# present in the checkpoint; the multi-view / pose forward path becomes
# available accordingly. ``cam_enc.dim_out`` matches the backbone's
# ``embed_dim`` so the cam token slots into block ``alt_start``.
# cam_enc / cam_dec are built only if the checkpoint has their weights; cam_enc output dim == embed_dim.
embed_dim = backbone_cfg["hidden_size"]
if has_cam_enc:
self.cam_enc = CameraEnc(
@ -163,8 +159,6 @@ class DepthAnything3Net(nn.Module):
else:
self.cam_enc = None
if has_cam_dec:
# Default cam_dec dim_in is 2*embed_dim when cat_token is on
# (the cls/cam token in the output is the cat'd version).
default_dim = embed_dim * (2 if cat_token else 1)
self.cam_dec = CameraDec(
dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
@ -175,9 +169,6 @@ class DepthAnything3Net(nn.Module):
self.dtype = dtype
# ------------------------------------------------------------------
# Forward
# ------------------------------------------------------------------
def forward(
self,
image: torch.Tensor,

View File

@ -24,9 +24,7 @@ from typing import Optional, Tuple
import torch
# -----------------------------------------------------------------------------
# Linear-algebra helpers
# -----------------------------------------------------------------------------
# qr/svd use fp32: CUDA often has no fp16/bf16 kernels for these ops.
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

View File

@ -766,7 +766,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
return dit_config
# Depth Anything 3 (Apache-2.0 monocular variants: Small/Base/Mono-Large/Metric-Large).
# Depth Anything 3
if '{}backbone.pretrained.patch_embed.proj.weight'.format(key_prefix) in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "DepthAnything3"

View File

@ -1864,41 +1864,16 @@ class DepthAnything3(supported_models_base.BASE):
return None
def process_unet_state_dict(self, state_dict):
# Drop weights for components we do not build (3D Gaussian heads).
# ``cam_enc.*`` / ``cam_dec.*`` are kept and consumed by the multi-view
# forward path -- their layouts in our ``camera.py`` mirror the
# upstream ``cam_enc.py`` / ``cam_dec.py`` so HF safetensors load
# directly without any key remap.
# Drop Gaussian-head weights (cam_enc.* / cam_dec.* are kept); remap upstream backbone keys (incl. fused QKV) to the Dinov2Model layout.
drop_prefixes = ("gs_head.", "gs_adapter.")
for k in list(state_dict.keys()):
if k.startswith(drop_prefixes):
state_dict.pop(k)
# Remap upstream DA3 backbone keys (``backbone.pretrained.*`` with
# fused QKV) to the layout used by ``comfy.image_encoders.dino2.Dinov2Model``.
return _da3_remap_backbone_keys(state_dict, prefix="backbone.")
def _da3_remap_backbone_keys(state_dict, prefix="backbone."):
"""Rewrite upstream DA3 DINOv2 keys to the shared ``Dinov2Model`` layout.
Upstream layout (under ``{prefix}pretrained.``):
patch_embed.proj.{weight,bias}, pos_embed, cls_token, camera_token, norm.*,
blocks.{i}.norm{1,2}.*, blocks.{i}.attn.qkv.{weight,bias},
blocks.{i}.attn.q_norm.*, blocks.{i}.attn.k_norm.*,
blocks.{i}.attn.proj.*, blocks.{i}.ls{1,2}.gamma,
blocks.{i}.mlp.fc{1,2}.* (or w12/w3 for SwiGLU)
Target layout (Dinov2Model under ``{prefix}``):
embeddings.patch_embeddings.projection.*,
embeddings.position_embeddings, embeddings.cls_token, embeddings.camera_token,
layernorm.*,
encoder.layer.{i}.norm{1,2}.*,
encoder.layer.{i}.attention.attention.{query,key,value}.*,
encoder.layer.{i}.attention.q_norm.*, encoder.layer.{i}.attention.k_norm.*,
encoder.layer.{i}.attention.output.dense.*,
encoder.layer.{i}.layer_scale{1,2}.lambda1,
encoder.layer.{i}.mlp.fc{1,2}.* (or weights_in/weights_out for SwiGLU)
"""
"""Map ``backbone.pretrained.*`` (upstream DA3) keys to ``Dinov2Model`` under ``prefix``."""
pre = prefix + "pretrained."
src_keys = [k for k in state_dict.keys() if k.startswith(pre)]
if not src_keys:

View File

@ -72,19 +72,9 @@ class LoadDepthAnything3(io.ComfyNode):
return io.NodeOutput(model)
# -----------------------------------------------------------------------------
# Inference helpers
# -----------------------------------------------------------------------------
def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
method: str = "upper_bound_resize"):
"""Run the DA3 network on a (B, H, W, 3) IMAGE batch.
Returns ``(depth, confidence, sky)`` tensors at the original image
resolution. ``confidence`` / ``sky`` are ``None`` when the variant does
not produce them.
"""
"""Run DA3 on ``(B,H,W,3)`` IMAGE; returns depth/conf/sky at original resolution (or None)."""
assert image.ndim == 4 and image.shape[-1] == 3, \
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
@ -95,7 +85,6 @@ def _run_da3(model_patcher, image: torch.Tensor, process_res: int,
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
depths, confs, skies = [], [], []
# Process one image at a time to keep peak memory predictable.
for i in range(B):
single = image[i:i + 1].to(device)
x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method)