code rabbit suggestions

This commit is contained in:
Yousef Rafat 2026-02-20 20:16:49 +02:00
parent c5a750205d
commit f3d4125e49
5 changed files with 11 additions and 13 deletions

View File

@ -188,7 +188,7 @@ class DINOv3ViTLayer(nn.Module):
device, dtype, operations):
super().__init__()
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps)
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
@ -273,8 +273,8 @@ class DINOv3ViTModel(nn.Module):
position_embeddings=position_embeddings,
)
self.norm = self.norm.to(hidden_states.device)
sequence_output = self.norm(hidden_states)
norm = self.norm.to(hidden_states.device)
sequence_output = norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return sequence_output, None, pooled_output, None

View File

@ -18,6 +18,6 @@
"rope_theta": 100.0,
"use_gated_mlp": false,
"value_bias": true,
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225]
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@ -6,7 +6,7 @@ from typing import Dict, Callable
NO_TRITION = False
try:
allow_tf32 = torch.cuda.is_tf32_supported
allow_tf32 = torch.cuda.is_tf32_supported()
except Exception:
allow_tf32 = False
try:

View File

@ -306,10 +306,7 @@ class ModulatedSparseTransformerBlock(nn.Module):
return x
def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
else:
return self._forward(x, mod)
return self._forward(x, mod)
class ModulatedSparseTransformerCrossBlock(nn.Module):
@ -486,6 +483,7 @@ class SLatFlowModel(nn.Module):
x = x.to(dtype)
h = self.input_layer(x)
h = manual_cast(h, self.dtype)
t = t.to(dtype)
t_emb = self.t_embedder(t, out_dtype = t.dtype)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
@ -790,7 +788,7 @@ class SparseStructureFlowModel(nn.Module):
return h
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
t_shifted /= 1000.0
t_shifted = t_shifted / 1000.0
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
t_new *= 1000.0

View File

@ -117,12 +117,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config["image_model"] = "trellis2"
unet_config["init_txt_model"] = False
if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
unet_config["init_txt_model"] = True
unet_config["resolution"] = 64
if metadata is not None:
if "is_512" in metadata and metadata["metadata"]:
if "is_512" in metadata:
unet_config["resolution"] = 32
unet_config["num_heads"] = 12