mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-20 11:32:58 +08:00
Since model saving is not implemented for the mixed quant system, this breaks model saving for every scaled fp8 model, which needs to be fixed before this gets merged.
38 lines
2.2 KiB
Python
38 lines
2.2 KiB
Python
from comfy import sd1_clip
|
|
from .spiece_tokenizer import SPieceTokenizer
|
|
import comfy.text_encoders.t5
|
|
import os
|
|
|
|
class UMT5XXlModel(sd1_clip.SDClipModel):
    """UMT5-XXL text encoder built on ``comfy.text_encoders.t5.T5``.

    The architecture config is read from ``umt5_config_xxl.json``, looked up
    in the directory that contains this module. Attention masks are enabled,
    masked positions are zeroed out, and the special tokens are fixed to
    ``end=1`` / ``pad=0``.
    """

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
        # Resolve the real path first so the config is located next to the
        # actual module file even when it is reached through a symlink.
        module_dir = os.path.dirname(os.path.realpath(__file__))
        config_path = os.path.join(module_dir, "umt5_config_xxl.json")

        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config=config_path,
            dtype=dtype,
            special_tokens={"end": 1, "pad": 0},
            model_class=comfy.text_encoders.t5.T5,
            enable_attention_masks=True,
            zero_out_masked=True,
            model_options=model_options,
        )
|
|
|
|
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
    """Tokenizer for the UMT5-XXL encoder, using ``SPieceTokenizer``.

    The tokenizer model is taken from ``tokenizer_data["spiece_model"]``
    (``None`` when the key is absent) and handed to the base class.
    """

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        # NOTE(review): embedding_directory is accepted but not forwarded to
        # the base class here -- confirm whether that is intentional.
        spiece_model = tokenizer_data.get("spiece_model")

        super().__init__(
            spiece_model,
            pad_with_end=False,
            embedding_size=4096,
            embedding_key='umt5xxl',
            tokenizer_class=SPieceTokenizer,
            has_start_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=512,
            pad_token=0,
            tokenizer_data=tokenizer_data,
        )

    def state_dict(self):
        """Return the serialized tokenizer model under the ``spiece_model`` key."""
        serialized = self.tokenizer.serialize_model()
        return {"spiece_model": serialized}
|
|
|
|
|
|
class WanT5Tokenizer(sd1_clip.SD1Tokenizer):
    """``SD1Tokenizer`` wrapper registering ``UMT5XXlTokenizer`` as ``umt5xxl``."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            clip_name="umt5xxl",
            tokenizer=UMT5XXlTokenizer,
        )
|
|
|
|
class WanT5Model(sd1_clip.SD1ClipModel):
    """``SD1ClipModel`` wrapper using ``UMT5XXlModel`` under the name ``umt5xxl``.

    Any extra keyword arguments are passed through to the base class unchanged.
    """

    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
        super().__init__(
            device=device,
            dtype=dtype,
            model_options=model_options,
            name="umt5xxl",
            clip_model=UMT5XXlModel,
            **kwargs,
        )
|
|
|
|
def te(dtype_t5=None, t5_quantization_metadata=None):
    """Create a ``WanT5Model`` subclass with preset dtype/quantization settings.

    Args:
        dtype_t5: When not ``None``, replaces whatever ``dtype`` is passed to
            the returned class's constructor.
        t5_quantization_metadata: When not ``None``, stored under the
            ``"quantization_metadata"`` key of a copy of ``model_options``.

    Returns:
        The ``WanTEModel`` class (not an instance).
    """

    class WanTEModel(WanT5Model):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            options = model_options
            if t5_quantization_metadata is not None:
                # Work on a copy so neither the caller's dict nor the shared
                # default argument is modified.
                options = model_options.copy()
                options["quantization_metadata"] = t5_quantization_metadata

            # The factory-level dtype, when given, wins over the per-call one.
            chosen_dtype = dtype if dtype_t5 is None else dtype_t5

            super().__init__(device=device, dtype=chosen_dtype, model_options=options)

    return WanTEModel
|