mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
debugging
This commit is contained in:
parent
b3da8ed4c5
commit
1fde60b2bc
@ -175,7 +175,7 @@ class DINOv3ViTEmbeddings(nn.Module):
|
||||
|
||||
cls_token = self.cls_token.expand(batch_size, -1, -1)
|
||||
register_tokens = self.register_tokens.expand(batch_size, -1, -1)
|
||||
device = patch_embeddings
|
||||
device = patch_embeddings.device
|
||||
cls_token = cls_token.to(device)
|
||||
register_tokens = register_tokens.to(device)
|
||||
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
|
||||
|
||||
@ -201,6 +201,8 @@ class SparseMultiHeadAttention(nn.Module):
|
||||
|
||||
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
|
||||
if self._type == "self":
|
||||
dtype = next(self.to_qkv.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
qkv = self._linear(self.to_qkv, x)
|
||||
qkv = self._fused_pre(qkv, num_fused=3)
|
||||
if self.qk_rms_norm or self.use_rope:
|
||||
@ -243,71 +245,6 @@ class SparseMultiHeadAttention(nn.Module):
|
||||
h = self._linear(self.to_out, h)
|
||||
return h
|
||||
|
||||
class ModulatedSparseTransformerBlock(nn.Module):
    """
    Sparse transformer block: a self-attention branch followed by a
    feed-forward branch, each wrapped in shift/scale/gate modulation.

    The modulation signal is split into six chunks along dim 1
    (shift, scale and gate for the attention branch, then the same three
    for the feed-forward branch).

    When ``share_mod`` is False the block owns a ``SiLU -> Linear`` stack
    (``adaLN_modulation``) that maps ``channels`` to ``6 * channels``.
    When ``share_mod`` is True the block instead holds a learned
    ``modulation`` parameter of size ``6 * channels`` that is added to the
    incoming ``mod`` before splitting.

    ``use_checkpoint`` is recorded on the instance; nothing in this class
    reads it afterwards.
    """

    def __init__(
        self,
        channels: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "swin"] = "full",
        window_size: Optional[int] = None,
        shift_window: Optional[Tuple[int, int, int]] = None,
        use_checkpoint: bool = False,
        use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
        qk_rms_norm: bool = False,
        qkv_bias: bool = True,
        share_mod: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.share_mod = share_mod

        # Non-affine norms: scale and shift come from the modulation chunks.
        self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
        self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)

        attention_options = dict(
            num_heads=num_heads,
            attn_mode=attn_mode,
            window_size=window_size,
            shift_window=shift_window,
            qkv_bias=qkv_bias,
            use_rope=use_rope,
            rope_freq=rope_freq,
            qk_rms_norm=qk_rms_norm,
        )
        self.attn = SparseMultiHeadAttention(channels, **attention_options)
        self.mlp = SparseFeedForwardNet(channels, mlp_ratio=mlp_ratio)

        if share_mod:
            # Learned per-block offset added to the externally supplied modulation.
            self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
        else:
            # Per-block projection of the conditioning into the six chunks.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(channels, 6 * channels, bias=True)
            )

    def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        """Apply the modulated attention and feed-forward branches to ``x``."""
        if self.share_mod:
            # Cast back to mod's dtype after adding the learned offset.
            mod_chunks = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
        else:
            mod_chunks = self.adaLN_modulation(mod).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod_chunks

        # Attention branch: normalize feats, modulate, attend, gate, residual add.
        attn_in = x.replace(self.norm1(x.feats))
        attn_in = attn_in * (1 + scale_msa) + shift_msa
        attn_out = self.attn(attn_in)
        x = x + attn_out * gate_msa

        # Feed-forward branch: same pattern with the second set of chunks.
        mlp_in = x.replace(self.norm2(x.feats))
        mlp_in = mlp_in * (1 + scale_mlp) + shift_mlp
        mlp_out = self.mlp(mlp_in)
        x = x + mlp_out * gate_mlp
        return x

    def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
        """Run the block on sparse input ``x`` with modulation ``mod``."""
        return self._forward(x, mod)
|
||||
|
||||
|
||||
class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
"""
|
||||
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
|
||||
@ -483,15 +420,13 @@ class SLatFlowModel(nn.Module):
|
||||
h = self.input_layer(x)
|
||||
h = manual_cast(h, self.dtype)
|
||||
t = t.to(dtype)
|
||||
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
||||
t_embedder = self.t_embedder.to(dtype)
|
||||
t_emb = t_embedder(t, out_dtype = t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
t_emb = manual_cast(t_emb, self.dtype)
|
||||
cond = manual_cast(cond, self.dtype)
|
||||
|
||||
if self.pe_mode == "ape":
|
||||
pe = self.pos_embedder(h.coords[:, 1:])
|
||||
h = h + manual_cast(pe, self.dtype)
|
||||
for block in self.blocks:
|
||||
h = block(h, t_emb, cond)
|
||||
|
||||
@ -849,7 +784,24 @@ class Trellis2(nn.Module):
|
||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||
|
||||
if not_struct_mode:
|
||||
x = SparseTensor(feats=x, coords=coords)
|
||||
B, N, C = x.shape
|
||||
|
||||
if mode == "shape_generation":
|
||||
feats_flat = x.reshape(-1, C)
|
||||
|
||||
# 3. inflate coords [N, 4] -> [B*N, 4]
|
||||
coords_list = []
|
||||
for i in range(B):
|
||||
c = coords.clone()
|
||||
c[:, 0] = i
|
||||
coords_list.append(c)
|
||||
|
||||
batched_coords = torch.cat(coords_list, dim=0)
|
||||
else: # TODO: texture
|
||||
# may remove the else if texture doesn't require special handling
|
||||
batched_coords = coords
|
||||
feats_flat = x
|
||||
x = SparseTensor(feats=feats_flat, coords=batched_coords)
|
||||
|
||||
if mode == "shape_generation":
|
||||
# TODO
|
||||
@ -868,4 +820,6 @@ class Trellis2(nn.Module):
|
||||
|
||||
if not_struct_mode:
|
||||
out = out.feats
|
||||
if mode == "shape_generation":
|
||||
out = out.view(B, N, -1)
|
||||
return out
|
||||
|
||||
@ -238,9 +238,9 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
max_size = max(image.size)
|
||||
scale = min(1, 1024 / max_size)
|
||||
if scale < 1:
|
||||
image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
|
||||
image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS)
|
||||
|
||||
image = torch.tensor(np.array(image)).unsqueeze(0)
|
||||
image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255
|
||||
|
||||
# could make 1024 an option
|
||||
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user