Fix some lowvram stuff with ace step 1.5 (#12312)

comfyanonymous, 2026-02-05 16:15:04 -08:00
parent 6555dc65b8, commit 458292fef0

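Every hunk below applies the same fix: a bare .to(dtype) cast assumed the buffer or parameter was already on the activation's device, which breaks in lowvram mode, where weights can sit offloaded on the CPU while activations live on the GPU. comfy.model_management.cast_to(tensor, device=..., dtype=...) moves and casts in one step. A minimal stand-in sketch of the pattern (the real helper in comfy.model_management may take additional options):

    import torch

    def cast_to(weight, dtype=None, device=None):
        # Minimal stand-in for comfy.model_management.cast_to: copy the
        # tensor onto the requested device and dtype before it is used.
        return weight.to(device=device, dtype=dtype)

    # A buffer left on the CPU (as lowvram offloading can do) next to a
    # half-precision activation; pretend x lives on a CUDA device here.
    buf = torch.randn(4)                      # float32, cpu
    x = torch.randn(4, dtype=torch.float16)   # the activation
    y = cast_to(buf, device=x.device, dtype=x.dtype) * x  # same device/dtype now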

@@ -738,7 +738,7 @@ class AttentionPooler(nn.Module):
     def forward(self, x):
         B, T, P, D = x.shape
         x = self.embed_tokens(x)
-        special = self.special_token.expand(B, T, 1, -1)
+        special = comfy.model_management.cast_to(self.special_token, device=x.device, dtype=x.dtype).expand(B, T, 1, -1)
         x = torch.cat([special, x], dim=2)
         x = x.view(B * T, P + 1, D)
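
The expand above broadcasts the single learned token across the batch and time axes as a view, without copying, which is why the patch casts the small buffer first and expands afterwards. A quick shape check with hypothetical sizes (special_token is assumed to be a (1, 1, 1, D) parameter):

    import torch

    B, T, P, D = 2, 3, 5, 16
    special_token = torch.zeros(1, 1, 1, D)       # one learned token (assumed shape)
    special = special_token.expand(B, T, 1, -1)   # view of shape (2, 3, 1, 16)
    x = torch.randn(B, T, P, D)
    x = torch.cat([special, x], dim=2)            # (2, 3, 6, 16): token prepended
    x = x.view(B * T, P + 1, D)                   # (6, 6, 16), as in the hunk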
@@ -789,7 +789,7 @@ class FSQ(nn.Module):
         self.register_buffer('implicit_codebook', implicit_codebook, persistent=False)

     def bound(self, z):
-        levels_minus_1 = (self._levels - 1).to(z.dtype)
+        levels_minus_1 = (comfy.model_management.cast_to(self._levels, device=z.device, dtype=z.dtype) - 1)
         scale = 2. / levels_minus_1
         bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5
@@ -804,8 +804,8 @@
         return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1.

     def codes_to_indices(self, zhat):
-        zhat_normalized = (zhat + 1.) / (2. / (self._levels.to(zhat.dtype) - 1))
-        return (zhat_normalized * self._basis.to(zhat.dtype)).sum(dim=-1).round().to(torch.int32)
+        zhat_normalized = (zhat + 1.) / (2. / (comfy.model_management.cast_to(self._levels, device=zhat.device, dtype=zhat.dtype) - 1))
+        return (zhat_normalized * comfy.model_management.cast_to(self._basis, device=zhat.device, dtype=zhat.dtype)).sum(dim=-1).round().to(torch.int32)

     def forward(self, z):
         orig_dtype = z.dtype
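
For context, the codes_to_indices arithmetic above maps each centered code from [-1, 1] back onto the integer grid [0, levels - 1], then flattens the digits with a mixed-radix basis. A worked example with hypothetical level counts (not taken from the ace step 1.5 config):

    import torch

    levels = torch.tensor([8, 5, 5])                       # per-dimension levels (hypothetical)
    basis = torch.cumprod(torch.tensor([1, 8, 5]), dim=0)  # mixed-radix basis: [1, 8, 40]

    zhat = torch.tensor([-1.0, 0.0, 1.0])                  # centered codes in [-1, 1]
    digits = (zhat + 1.) / (2. / (levels.float() - 1.))    # tensor([0., 2., 4.])
    index = (digits * basis.float()).sum().round().to(torch.int32)  # 0*1 + 2*8 + 4*40 = 176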
@@ -887,7 +887,7 @@ class ResidualFSQ(nn.Module):
         x = self.project_in(x)

         if hasattr(self, 'soft_clamp_input_value'):
-            sc_val = self.soft_clamp_input_value.to(x.dtype)
+            sc_val = comfy.model_management.cast_to(self.soft_clamp_input_value, device=x.device, dtype=x.dtype)
             x = (x / sc_val).tanh() * sc_val

         quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
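
The soft clamp above squashes x smoothly into (-sc_val, sc_val): divide by the bound, tanh, rescale. For inputs well inside the bound it is close to the identity, while large values saturate at the bound:

    import torch

    sc_val = torch.tensor(5.0)           # hypothetical clamp value
    x = torch.tensor([0.1, 3.0, 50.0])
    y = (x / sc_val).tanh() * sc_val     # approx [0.10, 2.69, 5.00]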
@@ -895,7 +895,7 @@ class ResidualFSQ(nn.Module):
         all_indices = []

         for layer, scale in zip(self.layers, self.scales):
-            scale = scale.to(residual.dtype)
+            scale = comfy.model_management.cast_to(scale, device=x.device, dtype=x.dtype)
             quantized, indices = layer(residual / scale)
             quantized = quantized * scale
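
The new line also widens the cast target from residual.dtype alone to x's device and dtype, so each per-layer scale follows the input even when self.scales stays on the CPU. A sketch of the loop it sits in, assuming the standard residual-quantization update (the hunk ends before the lines that follow quantized = quantized * scale):

    import torch

    def residual_quantize(x, layers, scales):
        # Each layer quantizes what the previous layers left over, at its own scale.
        quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype)
        residual = x
        all_indices = []
        for layer, scale in zip(layers, scales):
            scale = scale.to(device=x.device, dtype=x.dtype)  # the patched cast
            quantized, indices = layer(residual / scale)
            quantized = quantized * scale
            residual = residual - quantized       # assumed: pass the remainder on
            quantized_out = quantized_out + quantized
            all_indices.append(indices)
        return quantized_out, all_indices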