Support quantized Gemma in NewBie dual CLIP

This commit is contained in:
woctordho 2025-12-19 10:44:30 +08:00
parent 37a976a864
commit 91ff032fb5
2 changed files with 15 additions and 5 deletions

View File

@ -1272,7 +1272,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.NEWBIE:
clip_target.clip = comfy.text_encoders.newbie.NewBieClipModel
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer
if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]:
clip_data_gemma = clip_data[0]

View File

@ -21,12 +21,13 @@ class NewBieTokenizer:
def state_dict(self):
return {}
class NewBieClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}):
class NewBieTEModel(torch.nn.Module):
def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype, model_options=model_options)
dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device)
self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options)
self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = {dtype}
self.dtypes = {dtype, dtype_gemma}
def set_clip_options(self, options):
self.gemma.set_clip_options(options)
@ -50,3 +51,12 @@ class NewBieClipModel(torch.nn.Module):
return self.gemma.load_sd(sd)
else:
return self.jina.load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class NewBieTEModel_(NewBieTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return NewBieTEModel_