Code cleanup, don't force the fp32 layers as it has minimal effect

This commit is contained in:
kijai 2025-12-01 18:34:45 +02:00
parent c14bfb0554
commit dafa2695d4

View File

@ -39,13 +39,13 @@ class TimeEmbeddings(nn.Module):
self.max_period = max_period self.max_period = max_period
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False) self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
operations = operation_settings.get("operations") operations = operation_settings.get("operations")
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32) self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.activation = nn.SiLU() self.activation = nn.SiLU()
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32) self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, timestep): def forward(self, timestep, dtype):
args = torch.outer(timestep, self.freqs.to(device=timestep.device)) args = torch.outer(timestep, self.freqs.to(device=timestep.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed return time_embed
@ -67,7 +67,7 @@ class VisualEmbeddings(nn.Module):
super().__init__() super().__init__()
self.patch_size = patch_size self.patch_size = patch_size
operations = operation_settings.get("operations") operations = operation_settings.get("operations")
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=torch.float32) self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x): def forward(self, x):
x = x.movedim(1, -1) # B C T H W -> B T H W C x = x.movedim(1, -1) # B C T H W -> B T H W C
@ -82,17 +82,17 @@ class VisualEmbeddings(nn.Module):
dim, dim,
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7) ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
return self.in_layer(x.float()).to(x.dtype) return self.in_layer(x)
class Modulation(nn.Module): class Modulation(nn.Module):
def __init__(self, time_dim, model_dim, num_params, operation_settings=None): def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
super().__init__() super().__init__()
self.activation = nn.SiLU() self.activation = nn.SiLU()
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=torch.float32) self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x): def forward(self, x):
return self.out_layer(self.activation(x.float())).to(x.dtype) return self.out_layer(self.activation(x))
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
@ -125,13 +125,10 @@ class SelfAttention(nn.Module):
def _forward_chunked(self, x, freqs, transformer_options={}): def _forward_chunked(self, x, freqs, transformer_options={}):
def process_chunks(proj_fn, norm_fn): def process_chunks(proj_fn, norm_fn):
B, L, _ = x.shape x_chunks = torch.chunk(x, self.num_chunks, dim=1)
chunk_size = (L + self.num_chunks - 1) // self.num_chunks freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
chunks = [] chunks = []
for i in range(0, L, chunk_size): for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
end_idx = min(i + chunk_size, L)
x_chunk = x[:, i:end_idx]
freqs_chunk = freqs[:, i:end_idx]
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn)) chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
return torch.cat(chunks, dim=1) return torch.cat(chunks, dim=1)
@ -174,13 +171,11 @@ class FeedForward(nn.Module):
return self.out_layer(self.activation(self.in_layer(x))) return self.out_layer(self.activation(self.in_layer(x)))
def _forward_chunked(self, x): def _forward_chunked(self, x):
B, L, _ = x.shape chunks = torch.chunk(x, self.num_chunks, dim=1)
chunk_size = (L + self.num_chunks - 1) // self.num_chunks output_chunks = []
output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device) for chunk in chunks:
for i in range(0, L, chunk_size): output_chunks.append(self._forward(chunk))
end_idx = min(i + chunk_size, L) return torch.cat(output_chunks, dim=1)
output[:, i:end_idx] = self._forward(x[:, i:end_idx])
return output
def forward(self, x): def forward(self, x):
if x.shape[1] > 8192: if x.shape[1] > 8192:
@ -367,7 +362,7 @@ class Kandinsky5(nn.Module):
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs): def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
context = self.text_embeddings(context) context = self.text_embeddings(context)
time_embed = self.time_embeddings(timestep).to(x.dtype) + self.pooled_text_embeddings(y) time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
for block in self.text_transformer_blocks: for block in self.text_transformer_blocks:
context = block(context, time_embed, freqs_text, transformer_options=transformer_options) context = block(context, time_embed, freqs_text, transformer_options=transformer_options)