From d22cc94f53c0739f07b090afdb4da3bd29c76bb8 Mon Sep 17 00:00:00 2001
From: "jaystack.dev" <139488800+HyperExtendedReality@users.noreply.github.com>
Date: Thu, 15 Jan 2026 14:44:03 -0500
Subject: [PATCH] fix: ensure LTXAVTEModel uses half-precision for
 SageAttention compatibility

- Detect device capabilities and default to bfloat16 (falling back to
  fp16, then fp32) when no explicit dtype is provided
- Apply a provided dtype_llama/dtype consistently across the Gemma
  model, the projection layer, and the connectors, filling in whichever
  of the two is missing from the other
- Comment out the forced `out.float()` in encode_token_weights so the
  projected embeddings are not cast back to fp32 after projection
- Return the result of load_state_dict directly, dropping the filter
  that stripped gemma3_12b.* entries from the missing-keys list
- Together these allow SageAttention's optimized kernel to run instead
  of falling back to PyTorch attention

Fixes the warning: "Error running sage attention: Input tensors must be
in dtype of torch.float16 or torch.bfloat16, using pytorch attention
instead."
---
 comfy/text_encoders/lt.py | 26 ++++++++++++++++++++------
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index c33c77db7..b1977834e 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -61,11 +61,26 @@ class LTXAVTEModel(torch.nn.Module):
     def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         self.dtypes = set()
-        self.dtypes.add(dtype)
+        exec_device = device
+        if exec_device == "cpu":
+            exec_device = comfy.model_management.get_torch_device()
 
-        self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
+        if dtype_llama is None and dtype is None:
+            if comfy.model_management.should_use_bf16(exec_device):
+                dtype_llama = torch.bfloat16
+                dtype = torch.bfloat16
+            else:
+                dtype_llama = torch.float16 if hasattr(comfy.model_management, "should_use_fp16") and comfy.model_management.should_use_fp16(exec_device) else torch.float32
+                dtype = dtype_llama
+        elif dtype_llama is None:
+            dtype_llama = dtype
+        elif dtype is None:
+            dtype = dtype_llama
+
+        self.dtypes.add(dtype)
         self.dtypes.add(dtype_llama)
 
+        self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
         operations = self.gemma3_12b.operations # TODO
         self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
 
@@ -104,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
         out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
         out = out.reshape((out.shape[0], out.shape[1], -1))
         out = self.text_embedding_projection(out)
-        out = out.float()
+        #out = out.float()
         out_vid = self.video_embeddings_connector(out)[0]
         out_audio = self.audio_embeddings_connector(out)[0]
         out = torch.concat((out_vid, out_audio), dim=-1)
@@ -118,9 +133,8 @@ class LTXAVTEModel(torch.nn.Module):
         sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
         if len(sdo) == 0:
             sdo = sd
-        missing, unexpected = self.load_state_dict(sdo, strict=False)
-        missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
-        return (missing, unexpected)
+
+        return self.load_state_dict(sdo, strict=False)
 
     def memory_estimation_function(self, token_weight_pairs, device=None):
         constant = 6.0
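
Reviewer note: the dtype-resolution rule introduced in the first hunk,
restated below as a self-contained sketch. The helper name
resolve_dtypes and the boolean capability flags are illustrative
stand-ins for the inlined __init__ logic and the
comfy.model_management.should_use_bf16 / should_use_fp16 probes; the
patch itself does not add this function.

import torch

def resolve_dtypes(dtype_llama, dtype, supports_bf16, supports_fp16):
    # Prefer bf16, fall back to fp16, and use fp32 only when neither
    # half type is supported; when exactly one dtype is given, the
    # other inherits it (mirrors the patched LTXAVTEModel.__init__).
    if dtype_llama is None and dtype is None:
        if supports_bf16:
            dtype_llama = dtype = torch.bfloat16
        elif supports_fp16:
            dtype_llama = dtype = torch.float16
        else:
            dtype_llama = dtype = torch.float32
    elif dtype_llama is None:
        dtype_llama = dtype
    elif dtype is None:
        dtype = dtype_llama
    return dtype_llama, dtype

# SageAttention accepts only fp16/bf16 inputs, so on capable hardware
# both resolved dtypes land in half precision and its kernel is used:
assert resolve_dtypes(None, None, True, True) == (torch.bfloat16, torch.bfloat16)
assert resolve_dtypes(torch.bfloat16, None, True, True) == (torch.bfloat16, torch.bfloat16)
assert resolve_dtypes(None, None, False, False) == (torch.float32, torch.float32)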