mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
code rabbit suggestions
This commit is contained in:
parent
c5a750205d
commit
f3d4125e49
@ -188,7 +188,7 @@ class DINOv3ViTLayer(nn.Module):
|
||||
device, dtype, operations):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
@ -273,8 +273,8 @@ class DINOv3ViTModel(nn.Module):
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
self.norm = self.norm.to(hidden_states.device)
|
||||
sequence_output = self.norm(hidden_states)
|
||||
norm = self.norm.to(hidden_states.device)
|
||||
sequence_output = norm(hidden_states)
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
|
||||
return sequence_output, None, pooled_output, None
|
||||
|
||||
@ -18,6 +18,6 @@
|
||||
"rope_theta": 100.0,
|
||||
"use_gated_mlp": false,
|
||||
"value_bias": true,
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225]
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@ from typing import Dict, Callable
|
||||
|
||||
NO_TRITION = False
|
||||
try:
|
||||
allow_tf32 = torch.cuda.is_tf32_supported
|
||||
allow_tf32 = torch.cuda.is_tf32_supported()
|
||||
except Exception:
|
||||
allow_tf32 = False
|
||||
try:
|
||||
|
||||
@ -306,10 +306,7 @@ class ModulatedSparseTransformerBlock(nn.Module):
|
||||
return x
|
||||
|
||||
def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
|
||||
if self.use_checkpoint:
|
||||
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
|
||||
else:
|
||||
return self._forward(x, mod)
|
||||
return self._forward(x, mod)
|
||||
|
||||
|
||||
class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
@ -486,6 +483,7 @@ class SLatFlowModel(nn.Module):
|
||||
x = x.to(dtype)
|
||||
h = self.input_layer(x)
|
||||
h = manual_cast(h, self.dtype)
|
||||
t = t.to(dtype)
|
||||
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
@ -790,7 +788,7 @@ class SparseStructureFlowModel(nn.Module):
|
||||
return h
|
||||
|
||||
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
|
||||
t_shifted /= 1000.0
|
||||
t_shifted = t_shifted / 1000.0
|
||||
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
|
||||
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
|
||||
t_new *= 1000.0
|
||||
|
||||
@ -117,12 +117,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config["image_model"] = "trellis2"
|
||||
|
||||
unet_config["init_txt_model"] = False
|
||||
if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config["init_txt_model"] = True
|
||||
|
||||
unet_config["resolution"] = 64
|
||||
if metadata is not None:
|
||||
if "is_512" in metadata and metadata["metadata"]:
|
||||
if "is_512" in metadata:
|
||||
unet_config["resolution"] = 32
|
||||
|
||||
unet_config["num_heads"] = 12
|
||||
|
||||
Loading…
Reference in New Issue
Block a user