Add comment

This commit is contained in:
kijai 2026-05-26 09:50:48 +03:00
parent 69fe4139c8
commit 58dedcee72

View File

@ -210,7 +210,7 @@ class PiTBlock(nn.Module):
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
gate_mlp = gate_mlp.contiguous()
gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP
mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
del mlp_params, shift_mlp, scale_mlp
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks