mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-30 16:20:17 +08:00
clip: support assign load when taking clip from a ckpt
This commit is contained in:
parent
0eff43261b
commit
5dcd043d19
10
comfy/sd.py
10
comfy/sd.py
@ -290,6 +290,16 @@ class CLIP:
|
|||||||
if full_model:
|
if full_model:
|
||||||
return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||||
else:
|
else:
|
||||||
|
can_assign = self.patcher.is_dynamic()
|
||||||
|
self.cond_stage_model.can_assign_sd = can_assign
|
||||||
|
|
||||||
|
# The CLIP models are a pretty complex web of wrappers and its
|
||||||
|
# a bit of an API change to plumb this all the way through.
|
||||||
|
# So spray paint the model with this flag that the loading
|
||||||
|
# nn.Module can then inspect for itself.
|
||||||
|
for m in self.cond_stage_model.modules():
|
||||||
|
m.can_assign_sd = can_assign
|
||||||
|
|
||||||
return self.cond_stage_model.load_sd(sd)
|
return self.cond_stage_model.load_sd(sd)
|
||||||
|
|
||||||
def get_sd(self):
|
def get_sd(self):
|
||||||
|
|||||||
@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
return self(tokens)
|
return self(tokens)
|
||||||
|
|
||||||
def load_sd(self, sd):
|
def load_sd(self, sd):
|
||||||
return self.transformer.load_state_dict(sd, strict=False)
|
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||||
|
|
||||||
def parse_parentheses(string):
|
def parse_parentheses(string):
|
||||||
result = []
|
result = []
|
||||||
|
|||||||
@ -119,7 +119,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
if len(sdo) == 0:
|
if len(sdo) == 0:
|
||||||
sdo = sd
|
sdo = sd
|
||||||
|
|
||||||
return self.load_state_dict(sdo, strict=False)
|
return self.load_state_dict(sdo, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||||
|
|
||||||
def memory_estimation_function(self, token_weight_pairs, device=None):
|
def memory_estimation_function(self, token_weight_pairs, device=None):
|
||||||
constant = 6.0
|
constant = 6.0
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user