mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
model fixes
This commit is contained in:
parent
f2c0320fe8
commit
cdd7ced1e8
@ -26,10 +26,9 @@ class SparseFeedForwardNet(nn.Module):
|
||||
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
||||
return self.mlp(x)
|
||||
|
||||
def manual_cast(tensor, dtype):
|
||||
if not torch.is_autocast_enabled():
|
||||
return tensor.type(dtype)
|
||||
return tensor
|
||||
def manual_cast(obj, dtype):
|
||||
return obj.to(dtype=dtype)
|
||||
|
||||
class LayerNorm32(nn.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_dtype = x.dtype
|
||||
@ -88,6 +87,12 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
self.freqs = self.freqs.to(indices.device)
|
||||
phases = torch.outer(indices, self.freqs)
|
||||
phases = torch.polar(torch.ones_like(phases), phases)
|
||||
return phases
|
||||
|
||||
def forward(self, q, k=None):
|
||||
cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}'
|
||||
freqs_cis = q.get_spatial_cache(cache_name)
|
||||
@ -111,11 +116,15 @@ class SparseRotaryPositionEmbedder(nn.Module):
|
||||
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
|
||||
def forward(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
||||
if torch.is_complex(phases):
|
||||
phases = phases.to(torch.complex64)
|
||||
else:
|
||||
phases = phases.to(torch.float32)
|
||||
if phases.shape[-1] < self.head_dim // 2:
|
||||
padn = self.head_dim // 2 - phases.shape[-1]
|
||||
phases = torch.cat([phases, torch.polar(
|
||||
torch.ones(*phases.shape[:-1], padn, device=phases.device),
|
||||
torch.zeros(*phases.shape[:-1], padn, device=phases.device)
|
||||
torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32),
|
||||
torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32)
|
||||
)], dim=-1)
|
||||
return phases
|
||||
|
||||
@ -468,7 +477,7 @@ class SLatFlowModel(nn.Module):
|
||||
|
||||
h = self.input_layer(x)
|
||||
h = manual_cast(h, self.dtype)
|
||||
t_emb = self.t_embedder(t)
|
||||
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
t_emb = manual_cast(t_emb, self.dtype)
|
||||
@ -687,9 +696,12 @@ class SparseStructureFlowModel(nn.Module):
|
||||
initialization: str = 'vanilla',
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
operations=None,
|
||||
device = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
@ -706,7 +718,7 @@ class SparseStructureFlowModel(nn.Module):
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.dtype = dtype
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels)
|
||||
self.t_embedder = TimestepEmbedder(model_channels, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
@ -743,9 +755,6 @@ class SparseStructureFlowModel(nn.Module):
|
||||
|
||||
self.out_layer = nn.Linear(model_channels, out_channels)
|
||||
|
||||
self.initialize_weights()
|
||||
self.convert_to(self.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
||||
assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
|
||||
f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
|
||||
@ -755,7 +764,7 @@ class SparseStructureFlowModel(nn.Module):
|
||||
h = self.input_layer(h)
|
||||
if self.pe_mode == "ape":
|
||||
h = h + self.pos_emb[None]
|
||||
t_emb = self.t_embedder(t)
|
||||
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
t_emb = manual_cast(t_emb, self.dtype)
|
||||
@ -799,7 +808,6 @@ class Trellis2(nn.Module):
|
||||
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
args.pop("out_channels")
|
||||
args.pop("in_channels")
|
||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||
|
||||
def forward(self, x: NestedTensor, timestep, context, **kwargs):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user