mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-07 03:52:32 +08:00
Fix some lowvram stuff with ace step 1.5 (#12312)
This commit is contained in:
parent
6555dc65b8
commit
458292fef0
@ -738,7 +738,7 @@ class AttentionPooler(nn.Module):
|
||||
def forward(self, x):
|
||||
B, T, P, D = x.shape
|
||||
x = self.embed_tokens(x)
|
||||
special = self.special_token.expand(B, T, 1, -1)
|
||||
special = comfy.model_management.cast_to(self.special_token, device=x.device, dtype=x.dtype).expand(B, T, 1, -1)
|
||||
x = torch.cat([special, x], dim=2)
|
||||
x = x.view(B * T, P + 1, D)
|
||||
|
||||
@ -789,7 +789,7 @@ class FSQ(nn.Module):
|
||||
self.register_buffer('implicit_codebook', implicit_codebook, persistent=False)
|
||||
|
||||
def bound(self, z):
|
||||
levels_minus_1 = (self._levels - 1).to(z.dtype)
|
||||
levels_minus_1 = (comfy.model_management.cast_to(self._levels, device=z.device, dtype=z.dtype) - 1)
|
||||
scale = 2. / levels_minus_1
|
||||
bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5
|
||||
|
||||
@ -804,8 +804,8 @@ class FSQ(nn.Module):
|
||||
return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1.
|
||||
|
||||
def codes_to_indices(self, zhat):
|
||||
zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1))
|
||||
return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32)
|
||||
zhat_normalized = (zhat + 1.) / (2. / (comfy.model_management.cast_to(self._levels, device=zhat.device, dtype=zhat.dtype) - 1))
|
||||
return (zhat_normalized * comfy.model_management.cast_to(self._basis, device=zhat.device, dtype=zhat.dtype)).sum(dim=-1).round().to(torch.int32)
|
||||
|
||||
def forward(self, z):
|
||||
orig_dtype = z.dtype
|
||||
@ -887,7 +887,7 @@ class ResidualFSQ(nn.Module):
|
||||
x = self.project_in(x)
|
||||
|
||||
if hasattr(self, 'soft_clamp_input_value'):
|
||||
sc_val = self.soft_clamp_input_value.to(x.dtype)
|
||||
sc_val = comfy.model_management.cast_to(self.soft_clamp_input_value, device=x.device, dtype=x.dtype)
|
||||
x = (x / sc_val).tanh() * sc_val
|
||||
|
||||
quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
|
||||
@ -895,7 +895,7 @@ class ResidualFSQ(nn.Module):
|
||||
all_indices = []
|
||||
|
||||
for layer, scale in zip(self.layers, self.scales):
|
||||
scale = scale.to(residual.dtype)
|
||||
scale = comfy.model_management.cast_to(scale, device=x.device, dtype=x.dtype)
|
||||
|
||||
quantized, indices = layer(residual / scale)
|
||||
quantized = quantized * scale
|
||||
|
||||
Loading…
Reference in New Issue
Block a user