mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 10:17:31 +08:00
Fix ltxav te mem estimation. (#12643)
This commit is contained in:
parent
e14b04478c
commit
7253531670
@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import math
|
import math
|
||||||
|
import itertools
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -199,8 +200,10 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
constant /= 2.0
|
constant /= 2.0
|
||||||
|
|
||||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs])
|
||||||
num_tokens = max(num_tokens, 64)
|
|
||||||
|
num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m
|
||||||
|
num_tokens = max(num_tokens, 642)
|
||||||
return num_tokens * constant * 1024 * 1024
|
return num_tokens * constant * 1024 * 1024
|
||||||
|
|
||||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user