Code cleanup, don't force the fp32 layers as it has minimal effect

This commit is contained in:
kijai 2025-12-01 18:34:45 +02:00
parent c14bfb0554
commit dafa2695d4

View File

@ -39,13 +39,13 @@ class TimeEmbeddings(nn.Module):
self.max_period = max_period
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.activation = nn.SiLU()
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, timestep):
def forward(self, timestep, dtype):
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
@ -67,7 +67,7 @@ class VisualEmbeddings(nn.Module):
super().__init__()
self.patch_size = patch_size
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=torch.float32)
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
x = x.movedim(1, -1) # B C T H W -> B T H W C
@ -82,17 +82,17 @@ class VisualEmbeddings(nn.Module):
dim,
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
return self.in_layer(x.float()).to(x.dtype)
return self.in_layer(x)
class Modulation(nn.Module):
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
super().__init__()
self.activation = nn.SiLU()
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=torch.float32)
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
return self.out_layer(self.activation(x.float())).to(x.dtype)
return self.out_layer(self.activation(x))
class SelfAttention(nn.Module):
@ -125,13 +125,10 @@ class SelfAttention(nn.Module):
def _forward_chunked(self, x, freqs, transformer_options={}):
def process_chunks(proj_fn, norm_fn):
B, L, _ = x.shape
chunk_size = (L + self.num_chunks - 1) // self.num_chunks
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
chunks = []
for i in range(0, L, chunk_size):
end_idx = min(i + chunk_size, L)
x_chunk = x[:, i:end_idx]
freqs_chunk = freqs[:, i:end_idx]
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
return torch.cat(chunks, dim=1)
@ -174,13 +171,11 @@ class FeedForward(nn.Module):
return self.out_layer(self.activation(self.in_layer(x)))
def _forward_chunked(self, x):
B, L, _ = x.shape
chunk_size = (L + self.num_chunks - 1) // self.num_chunks
output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
for i in range(0, L, chunk_size):
end_idx = min(i + chunk_size, L)
output[:, i:end_idx] = self._forward(x[:, i:end_idx])
return output
chunks = torch.chunk(x, self.num_chunks, dim=1)
output_chunks = []
for chunk in chunks:
output_chunks.append(self._forward(chunk))
return torch.cat(output_chunks, dim=1)
def forward(self, x):
if x.shape[1] > 8192:
@ -367,7 +362,7 @@ class Kandinsky5(nn.Module):
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
context = self.text_embeddings(context)
time_embed = self.time_embeddings(timestep).to(x.dtype) + self.pooled_text_embeddings(y)
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
for block in self.text_transformer_blocks:
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)