mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-01 09:10:16 +08:00
clip: support assign load when taking clip from a ckpt
This commit is contained in:
parent
28dd1c4c1f
commit
cb41b22d23
10
comfy/sd.py
10
comfy/sd.py
@ -290,6 +290,16 @@ class CLIP:
|
|||||||
if full_model:
|
if full_model:
|
||||||
return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||||
else:
|
else:
|
||||||
|
can_assign = self.patcher.is_dynamic()
|
||||||
|
self.cond_stage_model.can_assign_sd = can_assign
|
||||||
|
|
||||||
|
# The CLIP models are a pretty complex web of wrappers and its
|
||||||
|
# a bit of an API change to plumb this all the way through.
|
||||||
|
# So spray paint the model with this flag that the loading
|
||||||
|
# nn.Module can then inspect for itself.
|
||||||
|
for m in self.cond_stage_model.modules():
|
||||||
|
m.can_assign_sd = can_assign
|
||||||
|
|
||||||
return self.cond_stage_model.load_sd(sd)
|
return self.cond_stage_model.load_sd(sd)
|
||||||
|
|
||||||
def get_sd(self):
|
def get_sd(self):
|
||||||
|
|||||||
@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
return self(tokens)
|
return self(tokens)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.transformer.load_state_dict(sd, strict=False)
|
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||||
|
|
||||||
def parse_parentheses(string):
|
def parse_parentheses(string):
|
||||||
result = []
|
result = []
|
||||||
|
|||||||
@ -118,7 +118,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
|
||||||
if len(sdo) == 0:
|
if len(sdo) == 0:
|
||||||
sdo = sd
|
sdo = sd
|
||||||
missing, unexpected = self.load_state_dict(sdo, strict=False)
|
missing, unexpected = self.load_state_dict(sdo, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||||
missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
|
missing = [k for k in missing if not k.startswith("gemma3_12b.")] # filter out keys that belong to the main gemma model
|
||||||
return (missing, unexpected)
|
return (missing, unexpected)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user