diff --git a/comfy/sd.py b/comfy/sd.py
index 12f1373c5..2acb50e53 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -290,6 +290,16 @@ class CLIP:
         if full_model:
             return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
         else:
+            can_assign = self.patcher.is_dynamic()
+            self.cond_stage_model.can_assign_sd = can_assign
+
+            # The CLIP models are a pretty complex web of wrappers and it's
+            # a bit of an API change to plumb this all the way through.
+            # So spray paint the model with this flag that the loading
+            # nn.Module can then inspect for itself.
+            for m in self.cond_stage_model.modules():
+                m.can_assign_sd = can_assign
+
             return self.cond_stage_model.load_sd(sd)
 
     def get_sd(self):
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index c512ca5d0..b9380e021 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         return self(tokens)
 
     def load_sd(self, sd):
-        return self.transformer.load_state_dict(sd, strict=False)
+        return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
 
 def parse_parentheses(string):
     result = []
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index 776e25e97..145c4e2ac 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -119,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
         if len(sdo) == 0:
             sdo = sd
 
-        return self.load_state_dict(sdo, strict=False)
+        return self.load_state_dict(sdo, strict=False, assign=getattr(self, "can_assign_sd", False))
 
     def memory_estimation_function(self, token_weight_pairs, device=None):
         constant = 6.0