From a310ca93d3ec582be646f3b919af647a140e357a Mon Sep 17 00:00:00 2001 From: Rattus Date: Thu, 15 Jan 2026 12:43:10 +1000 Subject: [PATCH] clip: support assign load when taking clip from a ckpt --- comfy/sd.py | 10 ++++++++++ comfy/sd1_clip.py | 2 +- comfy/text_encoders/lt.py | 2 +- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 8745793ef..42b2fd6df 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -291,6 +291,16 @@ class CLIP: if full_model: return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) else: + can_assign = self.patcher.is_dynamic() + self.cond_stage_model.can_assign_sd = can_assign + + # The CLIP models are a pretty complex web of wrappers and its + # a bit of an API change to plumb this all the way through. + # So spray paint the model with this flag that the loading + # nn.Module can then inspect for itself. + for m in self.cond_stage_model.modules(): + m.can_assign_sd = can_assign + return self.cond_stage_model.load_sd(sd) def get_sd(self): diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c512ca5d0..b9380e021 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): - return self.transformer.load_state_dict(sd, strict=False) + return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) def parse_parentheses(string): result = [] diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e49161964..26573fb12 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -125,7 +125,7 @@ class LTXAVTEModel(torch.nn.Module): for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} if component_sd: - missing, unexpected = component.load_state_dict(component_sd, strict=False) + missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) missing_all.extend([f"{prefix}{k}" for k in missing]) unexpected_all.extend([f"{prefix}{k}" for k in unexpected])