Merge branch 'comfyanonymous:master' into fix/secure-combo

This commit is contained in:
Dr.Lt.Data 2023-06-28 15:54:37 +09:00 committed by GitHub
commit f8ee81e9ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 6 deletions

View File

@ -156,10 +156,10 @@ class SDXLRefiner(BaseModel):
print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score) print(clip_pooled.shape, width, height, crop_w, crop_h, aesthetic_score)
out = [] out = []
out.append(self.embedder(torch.Tensor([width])))
out.append(self.embedder(torch.Tensor([height]))) out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([crop_w]))) out.append(self.embedder(torch.Tensor([width])))
out.append(self.embedder(torch.Tensor([crop_h]))) out.append(self.embedder(torch.Tensor([crop_h])))
out.append(self.embedder(torch.Tensor([crop_w])))
out.append(self.embedder(torch.Tensor([aesthetic_score]))) out.append(self.embedder(torch.Tensor([aesthetic_score])))
flat = torch.flatten(torch.cat(out))[None, ] flat = torch.flatten(torch.cat(out))[None, ]
return torch.cat((clip_pooled.to(flat.device), flat), dim=1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
@ -180,11 +180,11 @@ class SDXL(BaseModel):
print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height) print(clip_pooled.shape, width, height, crop_w, crop_h, target_width, target_height)
out = [] out = []
out.append(self.embedder(torch.Tensor([width])))
out.append(self.embedder(torch.Tensor([height]))) out.append(self.embedder(torch.Tensor([height])))
out.append(self.embedder(torch.Tensor([crop_w]))) out.append(self.embedder(torch.Tensor([width])))
out.append(self.embedder(torch.Tensor([crop_h]))) out.append(self.embedder(torch.Tensor([crop_h])))
out.append(self.embedder(torch.Tensor([target_width]))) out.append(self.embedder(torch.Tensor([crop_w])))
out.append(self.embedder(torch.Tensor([target_height]))) out.append(self.embedder(torch.Tensor([target_height])))
out.append(self.embedder(torch.Tensor([target_width])))
flat = torch.flatten(torch.cat(out))[None, ] flat = torch.flatten(torch.cat(out))[None, ]
return torch.cat((clip_pooled.to(flat.device), flat), dim=1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1)

View File

@ -223,13 +223,28 @@ def model_lora_keys(model, key_map={}):
counter += 1 counter += 1
counter = 0 counter = 0
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
for b in range(24): clip_l_present = False
for b in range(32):
for c in LORA_CLIP_MAP: for c in LORA_CLIP_MAP:
k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) k = "transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk: if k in sdk:
lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
key_map[lora_key] = k key_map[lora_key] = k
k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
key_map[lora_key] = k
clip_l_present = True
k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
if k in sdk:
if clip_l_present:
lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
else:
lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
key_map[lora_key] = k
#Locon stuff #Locon stuff
ds_counter = 0 ds_counter = 0

View File

@ -148,6 +148,25 @@ class ConditioningSetMask:
c.append(n) c.append(n)
return (c, ) return (c, )
class ConditioningZeroOut:
@classmethod
def INPUT_TYPES(s):
return {"required": {"conditioning": ("CONDITIONING", )}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "zero_out"
CATEGORY = "advanced/conditioning"
def zero_out(self, conditioning):
c = []
for t in conditioning:
d = t[1].copy()
if "pooled_output" in d:
d["pooled_output"] = torch.zeros_like(d["pooled_output"])
n = [torch.zeros_like(t[0]), d]
c.append(n)
return (c, )
class VAEDecode: class VAEDecode:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -1350,6 +1369,8 @@ NODE_CLASS_MAPPINGS = {
"LoadLatent": LoadLatent, "LoadLatent": LoadLatent,
"SaveLatent": SaveLatent, "SaveLatent": SaveLatent,
"ConditioningZeroOut": ConditioningZeroOut,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {