diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py
index 484622d76..2367fc42c 100644
--- a/comfy/ldm/trellis2/model.py
+++ b/comfy/ldm/trellis2/model.py
@@ -26,10 +26,9 @@ class SparseFeedForwardNet(nn.Module):
     def forward(self, x: VarLenTensor) -> VarLenTensor:
         return self.mlp(x)

-def manual_cast(tensor, dtype):
-    if not torch.is_autocast_enabled():
-        return tensor.type(dtype)
-    return tensor
+def manual_cast(obj, dtype):
+    return obj.to(dtype=dtype)
+
 class LayerNorm32(nn.LayerNorm):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_dtype = x.dtype
@@ -88,6 +87,12 @@ class SparseRotaryPositionEmbedder(nn.Module):
         return freqs_cis

+    def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
+        self.freqs = self.freqs.to(indices.device)
+        phases = torch.outer(indices, self.freqs)
+        phases = torch.polar(torch.ones_like(phases), phases)
+        return phases
+
     def forward(self, q, k=None):
         cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}'
         freqs_cis = q.get_spatial_cache(cache_name)
@@ -111,11 +116,15 @@ class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
     def forward(self, indices: torch.Tensor) -> torch.Tensor:
         phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
+        if torch.is_complex(phases):
+            phases = phases.to(torch.complex64)
+        else:
+            phases = phases.to(torch.float32)
         if phases.shape[-1] < self.head_dim // 2:
             padn = self.head_dim // 2 - phases.shape[-1]
             phases = torch.cat([phases, torch.polar(
-                torch.ones(*phases.shape[:-1], padn, device=phases.device),
-                torch.zeros(*phases.shape[:-1], padn, device=phases.device)
+                torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32),
+                torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32)
             )], dim=-1)
         return phases
@@ -468,7 +477,7 @@ class SLatFlowModel(nn.Module):
         h = self.input_layer(x)
         h = manual_cast(h, self.dtype)
-        t_emb = self.t_embedder(t)
+        t_emb = self.t_embedder(t, out_dtype = t.dtype)
         if self.share_mod:
             t_emb = self.adaLN_modulation(t_emb)
         t_emb = manual_cast(t_emb, self.dtype)
@@ -687,9 +696,12 @@ class SparseStructureFlowModel(nn.Module):
         initialization: str = 'vanilla',
         qk_rms_norm: bool = False,
         qk_rms_norm_cross: bool = False,
+        operations=None,
+        device = None,
         **kwargs
     ):
         super().__init__()
+        self.device = device
         self.resolution = resolution
         self.in_channels = in_channels
         self.model_channels = model_channels
@@ -706,7 +718,7 @@ class SparseStructureFlowModel(nn.Module):
         self.qk_rms_norm_cross = qk_rms_norm_cross
         self.dtype = dtype

-        self.t_embedder = TimestepEmbedder(model_channels)
+        self.t_embedder = TimestepEmbedder(model_channels, operations=operations)
         if share_mod:
             self.adaLN_modulation = nn.Sequential(
                 nn.SiLU(),
@@ -743,9 +755,6 @@ class SparseStructureFlowModel(nn.Module):
         self.out_layer = nn.Linear(model_channels, out_channels)

-        self.initialize_weights()
-        self.convert_to(self.dtype)
-
     def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
         assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
             f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
@@ -755,7 +764,7 @@ class SparseStructureFlowModel(nn.Module):
         h = self.input_layer(h)
         if self.pe_mode == "ape":
             h = h + self.pos_emb[None]
-        t_emb = self.t_embedder(t)
+        t_emb = self.t_embedder(t, out_dtype = t.dtype)
         if self.share_mod:
             t_emb = self.adaLN_modulation(t_emb)
         t_emb = manual_cast(t_emb, self.dtype)
@@ -799,7 +808,6 @@ class Trellis2(nn.Module):
         self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
         self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
         args.pop("out_channels")
-        args.pop("in_channels")
         self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)

     def forward(self, x: NestedTensor, timestep, context, **kwargs):