diff --git a/comfy/sd.py b/comfy/sd.py
index b8e2d0951..7de7dd9c6 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1272,7 +1272,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
         elif clip_type == CLIPType.NEWBIE:
-            clip_target.clip = comfy.text_encoders.newbie.NewBieClipModel
+            clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
             if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
                 clip_data_gemma = clip_data[0]
diff --git a/comfy/text_encoders/newbie.py b/comfy/text_encoders/newbie.py
index 5cc1b5926..31904462b 100644
--- a/comfy/text_encoders/newbie.py
+++ b/comfy/text_encoders/newbie.py
@@ -21,12 +21,13 @@ class NewBieTokenizer:
     def state_dict(self):
         return {}
 
-class NewBieClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
+class NewBieTEModel(torch.nn.Module):
+    def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
         super().__init__()
-        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype, model_options=model_options)
+        dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
+        self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
         self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
-        self.dtypes = {dtype}
+        self.dtypes = {dtype, dtype_gemma}
 
     def set_clip_options(self, options):
         self.gemma.set_clip_options(options)
@@ -50,3 +51,12 @@ class NewBieClipModel(torch.nn.Module):
             return self.gemma.load_sd(sd)
         else:
             return self.jina.load_sd(sd)
+
+def te(dtype_llama=None, llama_quantization_metadata=None):
+    class NewBieTEModel_(NewBieTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
+    return NewBieTEModel_