Code cleanup, don't force the fp32 layers as it has minimal effect
This commit is contained in:
parent c14bfb0554
commit dafa2695d4
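The change removes the hard-coded torch.float32 dtype on the embedding and modulation layers and lets them follow the dtype supplied through operation_settings, dropping the matching cast round-trips in the forward passes. A minimal sketch of the pattern, using plain torch.nn as a stand-in for ComfyUI's operations wrapper (TinyEmbed and the settings dict are illustrative, not the repo's code):

import torch
import torch.nn as nn

# Sketch only: operation_settings mimics the dict the repo passes around, e.g.
# {"operations": <ops module>, "device": ..., "dtype": ...}. Before this commit
# the Linear layers were pinned to torch.float32; now they follow the model dtype.
class TinyEmbed(nn.Module):
    def __init__(self, dim, operation_settings):
        super().__init__()
        ops = operation_settings.get("operations")
        self.in_layer = ops.Linear(
            dim, dim,
            device=operation_settings.get("device"),
            dtype=operation_settings.get("dtype"),  # was: dtype=torch.float32
        )

    def forward(self, x):
        # No x.float() ... .to(x.dtype) round-trip needed once the weights
        # already match the activation dtype.
        return self.in_layer(x)

settings = {"operations": nn, "device": "cpu", "dtype": torch.bfloat16}
m = TinyEmbed(64, settings)
print(m(torch.randn(2, 64, dtype=torch.bfloat16)).dtype)  # torch.bfloat16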
@@ -39,13 +39,13 @@ class TimeEmbeddings(nn.Module):
         self.max_period = max_period
         self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
         operations = operation_settings.get("operations")
-        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
+        self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.activation = nn.SiLU()
-        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=torch.float32)
+        self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

-    def forward(self, timestep):
+    def forward(self, timestep, dtype):
         args = torch.outer(timestep, self.freqs.to(device=timestep.device))
-        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
         time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
         return time_embed

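For context on the new `.to(dtype)` in TimeEmbeddings.forward: the cos/sin table is still built from the fp32 freqs buffer and only the result is cast. A standalone sketch, assuming the usual exp(-log(max_period) * k / half) schedule for get_freqs (get_freqs itself is not shown in this diff):

import math
import torch

def sinusoidal_time_embed(timestep, half_dim, dtype, max_period=10000.0):
    # Assumed frequency schedule; the repo's get_freqs() is not part of this hunk.
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half_dim, dtype=torch.float32) / half_dim
    )
    # timestep: (B,) tensor -> args: (B, half_dim)
    args = torch.outer(timestep.float(), freqs)
    # cos/sin computed in fp32, then cast to the model dtype, mirroring the new
    # `.to(dtype)` in TimeEmbeddings.forward.
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)

emb = sinusoidal_time_embed(torch.tensor([0.0, 500.0]), 128, torch.bfloat16)
print(emb.shape, emb.dtype)  # torch.Size([2, 256]) torch.bfloat16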
@@ -67,7 +67,7 @@ class VisualEmbeddings(nn.Module):
         super().__init__()
         self.patch_size = patch_size
         operations = operation_settings.get("operations")
-        self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=torch.float32)
+        self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

     def forward(self, x):
         x = x.movedim(1, -1) # B C T H W -> B T H W C
@@ -82,17 +82,17 @@ class VisualEmbeddings(nn.Module):
             dim,
         ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)

-        return self.in_layer(x.float()).to(x.dtype)
+        return self.in_layer(x)


 class Modulation(nn.Module):
     def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
         super().__init__()
         self.activation = nn.SiLU()
-        self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=torch.float32)
+        self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

     def forward(self, x):
-        return self.out_layer(self.activation(x.float())).to(x.dtype)
+        return self.out_layer(self.activation(x))


 class SelfAttention(nn.Module):
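The VisualEmbeddings and Modulation edits are the same simplification: once in_layer/out_layer are created in the model dtype, the explicit x.float() up-cast and the trailing .to(x.dtype) are redundant. A hypothetical MiniModulation (plain nn.Linear instead of the operations wrapper) showing the resulting forward:

import torch
import torch.nn as nn

# Illustrative stand-in for Modulation, not the repo class.
class MiniModulation(nn.Module):
    def __init__(self, time_dim, model_dim, num_params, dtype):
        super().__init__()
        self.activation = nn.SiLU()
        # Previously this layer was forced to torch.float32 and the forward pass
        # did `self.out_layer(self.activation(x.float())).to(x.dtype)`.
        self.out_layer = nn.Linear(time_dim, num_params * model_dim, dtype=dtype)

    def forward(self, x):
        # With the weights in the model dtype, no up-cast/down-cast is needed.
        return self.out_layer(self.activation(x))

mod = MiniModulation(256, 128, 6, dtype=torch.bfloat16)
out = mod(torch.randn(2, 256, dtype=torch.bfloat16))
print(out.shape, out.dtype)  # torch.Size([2, 768]) torch.bfloat16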
@@ -125,13 +125,10 @@ class SelfAttention(nn.Module):

     def _forward_chunked(self, x, freqs, transformer_options={}):
         def process_chunks(proj_fn, norm_fn):
-            B, L, _ = x.shape
-            chunk_size = (L + self.num_chunks - 1) // self.num_chunks
+            x_chunks = torch.chunk(x, self.num_chunks, dim=1)
+            freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
             chunks = []
-            for i in range(0, L, chunk_size):
-                end_idx = min(i + chunk_size, L)
-                x_chunk = x[:, i:end_idx]
-                freqs_chunk = freqs[:, i:end_idx]
+            for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
                 chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
             return torch.cat(chunks, dim=1)

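The chunked QK path is only a readability change: torch.chunk(x, num_chunks, dim=1) yields the same partition as the old manual ceil-division slicing (equal-size chunks with a shorter final one), so concatenating the per-chunk results reproduces the previous output. A quick equivalence check:

import torch

x = torch.randn(2, 100, 8)  # B=2, L=100, D=8
num_chunks = 3

# Old approach: manual slicing with a ceil-division chunk size.
L = x.shape[1]
chunk_size = (L + num_chunks - 1) // num_chunks
manual = [x[:, i:min(i + chunk_size, L)] for i in range(0, L, chunk_size)]

# New approach: torch.chunk along the sequence dimension.
chunked = torch.chunk(x, num_chunks, dim=1)

assert len(manual) == len(chunked)
assert all(torch.equal(a, b) for a, b in zip(manual, chunked))
print([c.shape[1] for c in chunked])  # [34, 34, 32]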
@@ -174,13 +171,11 @@ class FeedForward(nn.Module):
         return self.out_layer(self.activation(self.in_layer(x)))

     def _forward_chunked(self, x):
-        B, L, _ = x.shape
-        chunk_size = (L + self.num_chunks - 1) // self.num_chunks
-        output = torch.empty(B, L, self.out_layer.out_features, dtype=x.dtype, device=x.device)
-        for i in range(0, L, chunk_size):
-            end_idx = min(i + chunk_size, L)
-            output[:, i:end_idx] = self._forward(x[:, i:end_idx])
-
-        return output
+        chunks = torch.chunk(x, self.num_chunks, dim=1)
+        output_chunks = []
+        for chunk in chunks:
+            output_chunks.append(self._forward(chunk))
+        return torch.cat(output_chunks, dim=1)

     def forward(self, x):
         if x.shape[1] > 8192:
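FeedForward._forward_chunked gets the same treatment: rather than preallocating an output with torch.empty and writing slices into it, each chunk goes through the MLP and the pieces are concatenated. Splitting along the sequence dimension is safe because the feed-forward acts on each position independently. A self-contained sketch with a hypothetical position-wise FFN:

import torch
import torch.nn as nn

# Illustrative position-wise FFN, used only to demonstrate the chunked forward.
ffn = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))

def forward_chunked(x, num_chunks=4):
    # Split the sequence dimension, run each piece, and stitch the results back
    # together; equivalent to running the whole sequence at once because the FFN
    # has no cross-position interaction.
    out_chunks = [ffn(chunk) for chunk in torch.chunk(x, num_chunks, dim=1)]
    return torch.cat(out_chunks, dim=1)

x = torch.randn(2, 1000, 64)
assert torch.allclose(forward_chunked(x), ffn(x), atol=1e-6)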
@@ -367,7 +362,7 @@ class Kandinsky5(nn.Module):
     def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
         patches_replace = transformer_options.get("patches_replace", {})
         context = self.text_embeddings(context)
-        time_embed = self.time_embeddings(timestep).to(x.dtype) + self.pooled_text_embeddings(y)
+        time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)

         for block in self.text_transformer_blocks:
             context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
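At the call site, the latent dtype is now passed into the time embeddings instead of casting their output, so the cast happens once right after the sinusoid is built and the two Linear layers already run in that dtype. A toy module mirroring the new signature (names are illustrative, not the repo's code):

import torch
import torch.nn as nn

class ToyTimeEmbed(nn.Module):
    # Mirrors the new `forward(self, timestep, dtype)` convention from this commit.
    def __init__(self, dim):
        super().__init__()
        self.register_buffer("freqs", torch.linspace(1.0, 0.01, dim // 2), persistent=False)

    def forward(self, timestep, dtype):
        args = torch.outer(timestep, self.freqs)
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)

x = torch.randn(1, 8, 32, dtype=torch.bfloat16)   # stand-in latent
te = ToyTimeEmbed(32)
time_embed = te(torch.tensor([3.0]), x.dtype)     # was: te(torch.tensor([3.0])).to(x.dtype)
print(time_embed.dtype)                           # torch.bfloat16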