commit 1f7ba2f286
Merge d22cc94f53 into 034fac7054
@@ -61,11 +61,26 @@ class LTXAVTEModel(torch.nn.Module):
     def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
         self.dtypes = set()
-        self.dtypes.add(dtype)
-        self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
+        exec_device = device
+        if exec_device == "cpu":
+            exec_device = comfy.model_management.get_torch_device()
+
+        if dtype_llama is None and dtype is None:
+            if comfy.model_management.should_use_bf16(exec_device):
+                dtype_llama = torch.bfloat16
+                dtype = torch.bfloat16
+            else:
+                dtype_llama = torch.float16 if hasattr(comfy.model_management, "should_use_fp16") and comfy.model_management.should_use_fp16(exec_device) else torch.float32
+                dtype = dtype_llama
+        elif dtype_llama is None:
+            dtype_llama = dtype
+        elif dtype is None:
+            dtype = dtype_llama
+
+        self.dtypes.add(dtype)
         self.dtypes.add(dtype_llama)
+
+        self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
         operations = self.gemma3_12b.operations # TODO
         self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
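Review note: the dtype fallback added in this hunk is easiest to sanity-check in isolation. Below is a minimal sketch of the same chain, where bf16_ok/fp16_ok are hypothetical stand-ins for comfy.model_management.should_use_bf16/should_use_fp16 (which need a live device); it is an illustration of the logic, not the model code itself.

import torch

# Minimal sketch of the fallback chain above; bf16_ok/fp16_ok are
# hypothetical stand-ins for the device probes in comfy.model_management.
def resolve_dtypes(dtype_llama=None, dtype=None, bf16_ok=True, fp16_ok=True):
    if dtype_llama is None and dtype is None:
        # Neither dtype given: probe the device, preferring bf16, then fp16, then fp32.
        if bf16_ok:
            dtype_llama = dtype = torch.bfloat16
        else:
            dtype_llama = dtype = torch.float16 if fp16_ok else torch.float32
    elif dtype_llama is None:
        dtype_llama = dtype   # one explicit dtype fills in the other
    elif dtype is None:
        dtype = dtype_llama
    return dtype_llama, dtype

assert resolve_dtypes() == (torch.bfloat16, torch.bfloat16)
assert resolve_dtypes(bf16_ok=False, fp16_ok=False) == (torch.float32, torch.float32)
assert resolve_dtypes(dtype=torch.float16) == (torch.float16, torch.float16)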
@@ -104,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
         out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
         out = out.reshape((out.shape[0], out.shape[1], -1))
         out = self.text_embedding_projection(out)
-        out = out.float()
+        #out = out.float()
         out_vid = self.video_embeddings_connector(out)[0]
         out_audio = self.audio_embeddings_connector(out)[0]
         out = torch.concat((out_vid, out_audio), dim=-1)
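Review note: the only change in this hunk is commenting out the fp32 upcast, so the two connectors now see the projection's output dtype (presumably to stay in reduced precision). The normalization line above it is unchanged; a toy check of what it does, with made-up tensor shapes:

import torch

# Toy check of the per-sample rescaling kept above: centering on the mean and
# dividing by the (amax - amin) range over dims (1, 2) pins the value span to
# just under 8.0 regardless of the input scale. Shapes here are made up.
out = torch.randn(2, 49, 3840) * 123.0
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (
    out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
span = out.amax(dim=(1, 2)) - out.amin(dim=(1, 2))
assert torch.all(span <= 8.0)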
@@ -118,9 +133,8 @@ class LTXAVTEModel(torch.nn.Module):
         sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
         if len(sdo) == 0:
             sdo = sd
-        missing, unexpected = self.load_state_dict(sdo, strict=False)
-        missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
-        return (missing, unexpected)
+        return self.load_state_dict(sdo, strict=False)

     def memory_estimation_function(self, token_weight_pairs, device=None):
         constant = 6.0
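Review note: this hunk drops the post-load filtering of missing keys. torch's load_state_dict(strict=False) returns a named tuple of (missing_keys, unexpected_keys), so after this change callers will see gemma3_12b.* keys reported as missing whenever sd only carries the projection/connector weights. A toy illustration of that return value, with made-up module and key names:

import torch

# Toy illustration of what strict=False returns: an IncompatibleKeys named
# tuple of (missing_keys, unexpected_keys). The lines removed above stripped
# "gemma3_12b." entries from missing_keys before returning; without the
# filter, those keys surface as missing. Module and key names are made up.
class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.text_embedding_projection = torch.nn.Linear(4, 4, bias=False)
        self.gemma3_12b = torch.nn.Linear(4, 4, bias=False)  # stand-in for the LLM

m = Toy()
result = m.load_state_dict({"text_embedding_projection.weight": torch.zeros(4, 4)}, strict=False)
print(result.missing_keys)     # ['gemma3_12b.weight'] -- what the old filter stripped
print(result.unexpected_keys)  # []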