This commit is contained in:
jaystack.dev 2026-01-18 14:41:27 +08:00 committed by GitHub
commit 1f7ba2f286
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -61,11 +61,26 @@ class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
self.dtypes.add(dtype)
exec_device = device
if exec_device == "cpu":
exec_device = comfy.model_management.get_torch_device()
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
if dtype_llama is None and dtype is None:
if comfy.model_management.should_use_bf16(exec_device):
dtype_llama = torch.bfloat16
dtype = torch.bfloat16
else:
dtype_llama = torch.float16 if hasattr(comfy.model_management, "should_use_fp16") and comfy.model_management.should_use_fp16(exec_device) else torch.float32
dtype = dtype_llama
elif dtype_llama is None:
dtype_llama = dtype
elif dtype is None:
dtype = dtype_llama
self.dtypes.add(dtype)
self.dtypes.add(dtype_llama)
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
@ -104,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
#out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
@ -118,9 +133,8 @@ class LTXAVTEModel(torch.nn.Module):
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
missing, unexpected = self.load_state_dict(sdo, strict=False)
missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
return (missing, unexpected)
return self.load_state_dict(sdo, strict=False)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0