clip: support assign load when taking clip from a ckpt

Rattus 2026-01-15 12:43:10 +10:00
parent 0eff43261b
commit 5dcd043d19
3 changed files with 12 additions and 2 deletions


@@ -290,6 +290,16 @@ class CLIP:
         if full_model:
             return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
         else:
+            can_assign = self.patcher.is_dynamic()
+            self.cond_stage_model.can_assign_sd = can_assign
+            # The CLIP models are a pretty complex web of wrappers and it's
+            # a bit of an API change to plumb this all the way through.
+            # So spray paint the model with this flag that the loading
+            # nn.Module can then inspect for itself.
+            for m in self.cond_stage_model.modules():
+                m.can_assign_sd = can_assign
             return self.cond_stage_model.load_sd(sd)

     def get_sd(self):
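For context, a minimal sketch of the flag-spraying pattern this hunk introduces. InnerTextModel and Wrapper are hypothetical stand-ins for the real wrapper web, and the assign keyword on load_state_dict assumes PyTorch 2.1 or newer:

import torch
import torch.nn as nn

class InnerTextModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_sd(self, sd):
        # Read the sprayed-on flag back with a safe default, so the
        # module still works when nothing ever set can_assign_sd.
        return self.load_state_dict(sd, strict=False,
                                    assign=getattr(self, "can_assign_sd", False))

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = InnerTextModel()

model = Wrapper()
for m in model.modules():
    # Tag every submodule instead of threading a new parameter
    # through each wrapper layer's API.
    m.can_assign_sd = True
model.inner.load_sd({"proj.weight": torch.zeros(4, 4)})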


@@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         return self(tokens)

     def load_sd(self, sd):
-        return self.transformer.load_state_dict(sd, strict=False)
+        return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))

 def parse_parentheses(string):
     result = []
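In isolation, the new assign argument changes where the loaded weights live: assign=False (the old behavior) copies values into the module's existing parameter storage, while assign=True makes the parameter adopt the checkpoint tensor's storage directly. A tiny check, assuming PyTorch 2.1 or newer:

import torch
import torch.nn as nn

lin = nn.Linear(2, 2, bias=False)
w = torch.ones(2, 2)

lin.load_state_dict({"weight": w}, assign=False)
print(lin.weight.data_ptr() == w.data_ptr())  # False: values copied in place

lin.load_state_dict({"weight": w}, assign=True)
print(lin.weight.data_ptr() == w.data_ptr())  # True: storage adopted directly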


@@ -119,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
         if len(sdo) == 0:
             sdo = sd
-        return self.load_state_dict(sdo, strict=False)
+        return self.load_state_dict(sdo, strict=False, assign=getattr(self, "can_assign_sd", False))

     def memory_estimation_function(self, token_weight_pairs, device=None):
         constant = 6.0
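One practical payoff of threading assign through these text-encoder loaders: combined with memory-mapped checkpoint reading (the mmap flag on torch.load, also PyTorch 2.1+), the module can adopt the mapped storage instead of holding a second copy of the weights in RAM. A hedged sketch; TextEncoder and the temp-file path are illustrative, not names from this diff:

import os
import tempfile
import torch
import torch.nn as nn

class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

path = os.path.join(tempfile.mkdtemp(), "te.ckpt")
torch.save(TextEncoder().state_dict(), path)

# mmap=True maps the checkpoint file instead of reading it eagerly;
# assign=True then hands those mapped tensors straight to the module.
sd = torch.load(path, map_location="cpu", mmap=True)
model = TextEncoder()
model.load_state_dict(sd, strict=False, assign=True)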