fix: ensure LTXAVTEModel uses half-precision for SageAttention compatibility

- Add automatic detection and default to bfloat16 (or fp16 fallback) when no explicit dtype is provided, based on device capabilities
- Respect provided dtype_llama/dtype consistently across Gemma model, projection layer, and connectors
- Remove forced `out.float()` in encode_token_weights to prevent downgrading to fp32 after projection
- This allows SageAttention's optimized kernel to run instead of falling back to PyTorch attention

Fixes the warning:
"Error running sage attention: Input tensors must be in dtype of torch.float16 or torch.bfloat16, using pytorch attention instead."
This commit is contained in:
jaystack.dev 2026-01-15 14:44:03 -05:00 committed by GitHub
parent 12918a5f78
commit d22cc94f53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -61,11 +61,26 @@ class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
self.dtypes.add(dtype)
exec_device = device
if exec_device == "cpu":
exec_device = comfy.model_management.get_torch_device()
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
if dtype_llama is None and dtype is None:
if comfy.model_management.should_use_bf16(exec_device):
dtype_llama = torch.bfloat16
dtype = torch.bfloat16
else:
dtype_llama = torch.float16 if hasattr(comfy.model_management, "should_use_fp16") and comfy.model_management.should_use_fp16(exec_device) else torch.float32
dtype = dtype_llama
elif dtype_llama is None:
dtype_llama = dtype
elif dtype is None:
dtype = dtype_llama
self.dtypes.add(dtype)
self.dtypes.add(dtype_llama)
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
@ -104,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
#out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
@ -118,9 +133,8 @@ class LTXAVTEModel(torch.nn.Module):
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
missing, unexpected = self.load_state_dict(sdo, strict=False)
missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
return (missing, unexpected)
return self.load_state_dict(sdo, strict=False)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0