mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-30 00:00:26 +08:00
Merge d22cc94f53 into 034fac7054
This commit is contained in:
commit
1f7ba2f286
@ -61,11 +61,26 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
self.dtypes.add(dtype)
|
||||
exec_device = device
|
||||
if exec_device == "cpu":
|
||||
exec_device = comfy.model_management.get_torch_device()
|
||||
|
||||
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||
if dtype_llama is None and dtype is None:
|
||||
if comfy.model_management.should_use_bf16(exec_device):
|
||||
dtype_llama = torch.bfloat16
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
dtype_llama = torch.float16 if hasattr(comfy.model_management, "should_use_fp16") and comfy.model_management.should_use_fp16(exec_device) else torch.float32
|
||||
dtype = dtype_llama
|
||||
elif dtype_llama is None:
|
||||
dtype_llama = dtype
|
||||
elif dtype is None:
|
||||
dtype = dtype_llama
|
||||
|
||||
self.dtypes.add(dtype)
|
||||
self.dtypes.add(dtype_llama)
|
||||
|
||||
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||
operations = self.gemma3_12b.operations # TODO
|
||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||
|
||||
@ -104,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||
out = self.text_embedding_projection(out)
|
||||
out = out.float()
|
||||
#out = out.float()
|
||||
out_vid = self.video_embeddings_connector(out)[0]
|
||||
out_audio = self.audio_embeddings_connector(out)[0]
|
||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||
@ -118,9 +133,8 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
missing, unexpected = self.load_state_dict(sdo, strict=False)
|
||||
missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
|
||||
return (missing, unexpected)
|
||||
|
||||
return self.load_state_dict(sdo, strict=False)
|
||||
|
||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||
constant = 6.0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user