From dafa2695d4796fcd2b3d4fc05fdba0856092d9f7 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Mon, 1 Dec 2025 18:34:45 +0200
Subject: [PATCH] Code cleanup, don't force the fp32 layers as it has minimal effect

---
 comfy/ldm/kandinsky5/model.py | 39 +++++++++++++++--------------------
 1 file changed, 17 insertions(+), 22 deletions(-)

diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index 2bcaae195..a653e02fc 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -39,13 +39,13 @@ class TimeEmbeddings(nn.Module):
         self.max_period = max_period
         self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
         operations = operation_settings.get("operations")
-        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
+        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.activation = nn.SiLU()
-        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
+        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
-    def forward(self, timestep):
+    def forward(self, timestep, dtype):
         args = torch.outer(timestep, self.freqs.to(device=timestep.device))
-        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
         time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed
 
@@ -67,7 +67,7 @@ class VisualEmbeddings(nn.Module):
         super().__init__()
         self.patch_size = patch_size
         operations = operation_settings.get("operations")
-        self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=torch.float32)
+        self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
     def forward(self, x):
         x = x.movedim(1, -1) # B C T H W -> B T H W C
@@ -82,17 +82,17 @@ class VisualEmbeddings(nn.Module):
             dim,
         ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
 
-        return self.in_layer(x.float()).to(x.dtype)
+        return self.in_layer(x)
 
 
 class Modulation(nn.Module):
     def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
         super().__init__()
         self.activation = nn.SiLU()
-        self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=torch.float32)
+        self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
 
     def forward(self, x):
-        return self.out_layer(self.activation(x.float())).to(x.dtype)
+        return self.out_layer(self.activation(x))
 
 
 class SelfAttention(nn.Module):
@@ -125,13 +125,10 @@ class SelfAttention(nn.Module):
 
     def _forward_chunked(self, x, freqs, transformer_options={}):
         def process_chunks(proj_fn, norm_fn):
-            B, L, _ = x.shape
-            chunk_size = (L + self.num_chunks - 1) // self.num_chunks
+            x_chunks = torch.chunk(x, self.num_chunks, dim=1)
+            freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
             chunks = []
-            for i in range(0, L, chunk_size):
-                end_idx = min(i + chunk_size, L)
-                x_chunk = x[:, i:end_idx]
-                freqs_chunk = freqs[:, i:end_idx]
+            for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
                 chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
             return torch.cat(chunks, dim=1)
 
@@ -174,13 +171,11 @@ class FeedForward(nn.Module):
         return self.out_layer(self.activation(self.in_layer(x)))
 
     def _forward_chunked(self, x):
-        B, L, _ = x.shape
-        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
-        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
-        for i in range(0, L, chunk_size):
-            end_idx = min(i + chunk_size, L)
-            output[:, i:end_idx] = self._forward(x[:, i:end_idx])
-        return output
+        chunks = torch.chunk(x, self.num_chunks, dim=1)
+        output_chunks = []
+        for chunk in chunks:
+            output_chunks.append(self._forward(chunk))
+        return torch.cat(output_chunks, dim=1)
 
     def forward(self, x):
         if x.shape[1] > 8192:
@@ -367,7 +362,7 @@ class Kandinsky5(nn.Module):
     def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
         patches_replace = transformer_options.get("patches_replace", {})
         context = self.text_embeddings(context)
-        time_embed = self.time_embeddings(timestep).to(x.dtype) + self.pooled_text_embeddings(y)
+        time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
         for block in self.text_transformer_blocks:
             context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
 
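For reviewers, a small standalone sketch (not part of the patch; names are illustrative) of why the new torch.chunk-based loops in _forward_chunked are drop-in replacements for the manual slicing they remove: torch.chunk(x, n, dim=1) splits the sequence dimension into pieces of size ceil(L / n), the same sizes the old ceil-divided chunk_size loop produced, so concatenating the per-chunk results gives the same output. fake_forward here is a hypothetical stand-in for the per-chunk work (self._forward / self._compute_qk in the module).

import torch

def manual_chunked(x, num_chunks, fn):
    # Old pattern: ceil-divide the sequence length and slice manually.
    B, L, _ = x.shape
    chunk_size = (L + num_chunks - 1) // num_chunks
    outputs = []
    for i in range(0, L, chunk_size):
        outputs.append(fn(x[:, i:min(i + chunk_size, L)]))
    return torch.cat(outputs, dim=1)

def chunked(x, num_chunks, fn):
    # New pattern: torch.chunk splits dim=1 into pieces of size
    # ceil(L / num_chunks); only the last piece may be shorter.
    return torch.cat([fn(c) for c in torch.chunk(x, num_chunks, dim=1)], dim=1)

if __name__ == "__main__":
    fake_forward = lambda t: t * 2.0   # hypothetical per-chunk computation
    x = torch.randn(2, 10001, 64)      # uneven length to exercise the last chunk
    assert torch.equal(manual_chunked(x, 4, fake_forward), chunked(x, 4, fake_forward))

Because SelfAttention chunks x and freqs with the same num_chunks along the same dimension, the zip in the new loop keeps the corresponding pieces aligned.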