Lower ltxv text encoder vram use. (#11713)

This commit is contained in:
comfyanonymous 2026-01-07 16:12:15 -08:00 committed by GitHub
parent 1c705f7bfb
commit 34751fe9f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -98,10 +98,13 @@ class LTXAVTEModel(torch.nn.Module):
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)