debugging

This commit is contained in:
Yousef Rafat 2026-02-20 22:04:37 +02:00
parent b3da8ed4c5
commit 1fde60b2bc
3 changed files with 27 additions and 73 deletions

View File

@ -175,7 +175,7 @@ class DINOv3ViTEmbeddings(nn.Module):
cls_token = self.cls_token.expand(batch_size, -1, -1)
register_tokens = self.register_tokens.expand(batch_size, -1, -1)
device = patch_embeddings
device = patch_embeddings.device
cls_token = cls_token.to(device)
register_tokens = register_tokens.to(device)
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)

View File

@ -201,6 +201,8 @@ class SparseMultiHeadAttention(nn.Module):
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
if self._type == "self":
dtype = next(self.to_qkv.parameters()).dtype
x = x.to(dtype)
qkv = self._linear(self.to_qkv, x)
qkv = self._fused_pre(qkv, num_fused=3)
if self.qk_rms_norm or self.use_rope:
@ -243,71 +245,6 @@ class SparseMultiHeadAttention(nn.Module):
h = self._linear(self.to_out, h)
return h
class ModulatedSparseTransformerBlock(nn.Module):
    """
    Sparse Transformer block (MSA + FFN) with adaptive layer norm (adaLN) conditioning.

    The conditioning vector ``mod`` is projected (or, when ``share_mod`` is True,
    offset by a learned parameter) into six chunks: shift/scale/gate for the
    attention branch and shift/scale/gate for the MLP branch.
    """

    def __init__(
        self,
        channels: int,                                    # feature width of the sparse tokens
        num_heads: int,                                   # attention heads, forwarded to SparseMultiHeadAttention
        mlp_ratio: float = 4.0,                           # FFN hidden width = channels * mlp_ratio
        attn_mode: Literal["full", "swin"] = "full",      # global attention or shifted-window attention
        window_size: Optional[int] = None,                # only meaningful for attn_mode == "swin"
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,                     # NOTE(review): stored but not consulted in forward() here — confirm intended
        use_rope: bool = False,                           # rotary position embedding inside attention
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,                        # RMS-normalize q/k before attention
        qkv_bias: bool = True,
        share_mod: bool = False,                          # True: learned additive modulation; False: per-block adaLN MLP
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod
        # Affine-free norms: scale/shift come from the modulation instead.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.attn = SparseMultiHeadAttention(
            channels,
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.mlp = SparseFeedForwardNet(
            channels,
            mlp_ratio=mlp_ratio,
        )
        if not share_mod:
            # Per-block adaLN: project the conditioning vector to 6 * channels
            # (shift/scale/gate for MSA, shift/scale/gate for MLP).
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )
        else:
            # Shared-mod variant: a learned offset added to an externally
            # computed modulation tensor. Scaled init (~1/sqrt(channels)).
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)

    def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        # Split the modulation into the six adaLN terms. `mod` is chunked on
        # dim=1, so it is assumed to be (batch, 6 * channels) when share_mod
        # is True, or (batch, channels) otherwise — TODO confirm with callers.
        if self.share_mod:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
        # Attention branch: modulated pre-norm -> attention -> gated residual.
        h = x.replace(self.norm1(x.feats))
        h = h * (1 + scale_msa) + shift_msa
        h = self.attn(h)
        h = h * gate_msa
        x = x + h
        # MLP branch: modulated pre-norm -> FFN -> gated residual.
        h = x.replace(self.norm2(x.feats))
        h = h * (1 + scale_mlp) + shift_mlp
        h = self.mlp(h)
        h = h * gate_mlp
        x = x + h
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        # Thin public entry point; despite use_checkpoint being stored in
        # __init__, no activation checkpointing is applied here.
        return self._forward(x, mod)
class ModulatedSparseTransformerCrossBlock(nn.Module):
"""
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
@ -483,15 +420,13 @@ class SLatFlowModel(nn.Module):
h = self.input_layer(x)
h = manual_cast(h, self.dtype)
t = t.to(dtype)
t_emb = self.t_embedder(t, out_dtype = t.dtype)
t_embedder = self.t_embedder.to(dtype)
t_emb = t_embedder(t, out_dtype = t.dtype)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
t_emb = manual_cast(t_emb, self.dtype)
cond = manual_cast(cond, self.dtype)
if self.pe_mode == "ape":
pe = self.pos_embedder(h.coords[:, 1:])
h = h + manual_cast(pe, self.dtype)
for block in self.blocks:
h = block(h, t_emb, cond)
@ -849,7 +784,24 @@ class Trellis2(nn.Module):
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
if not_struct_mode:
x = SparseTensor(feats=x, coords=coords)
B, N, C = x.shape
if mode == "shape_generation":
feats_flat = x.reshape(-1, C)
# 3. inflate coords [N, 4] -> [B*N, 4]
coords_list = []
for i in range(B):
c = coords.clone()
c[:, 0] = i
coords_list.append(c)
batched_coords = torch.cat(coords_list, dim=0)
else: # TODO: texture
# may remove the else if texture doesn't require special handling
batched_coords = coords
feats_flat = x
x = SparseTensor(feats=feats_flat, coords=batched_coords)
if mode == "shape_generation":
# TODO
@ -868,4 +820,6 @@ class Trellis2(nn.Module):
if not_struct_mode:
out = out.feats
if mode == "shape_generation":
out = out.view(B, N, -1)
return out

View File

@ -238,9 +238,9 @@ class Trellis2Conditioning(IO.ComfyNode):
max_size = max(image.size)
scale = min(1, 1024 / max_size)
if scale < 1:
image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS)
image = torch.tensor(np.array(image)).unsqueeze(0)
image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255
# could make 1024 an option
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)