feat: LoraLoaderBlockWeights

LoRA block weights: adds per-UNet-block strength multipliers that are applied on top of the base LoRA strength.
Author: Dr.Lt.Data
Date: 2023-03-28 22:33:57 +09:00
parent 31dd6c0531
commit bd0a8163eb
3 changed files with 103 additions and 9 deletions


@@ -280,21 +280,31 @@ class ModelPatcher:
         n.patches = self.patches[:]
         return n

-    def add_patches(self, patches, strength=1.0):
+    def add_patches(self, patches, strength=1.0, block_weights={}):
         p = {}
         model_sd = self.model.state_dict()
         for k in patches:
             if k in model_sd:
-                p[k] = patches[k]
-        self.patches += [(strength, p)]
+                sk = k.split(".")
+                block_key = ".".join(sk[2:4])
+
+                if block_weights.__contains__(block_key):
+                    # apply block weights
+                    p[k] = (strength * block_weights[block_key], patches[k])
+                else:
+                    # apply only base strength
+                    p[k] = (strength, patches[k])
+
+        self.patches += [p]
         return p.keys()

     def patch_model(self):
         model_sd = self.model.state_dict()
         for p in self.patches:
-            for k in p[1]:
-                v = p[1][k]
+            for k in p:
+                v = p[k][1]
                 key = k

                 if key not in model_sd:
                     print("could not patch. key doesn't exist in model:", k)
                     continue
@@ -303,7 +313,14 @@ class ModelPatcher:
                 if key not in self.backup:
                     self.backup[key] = weight.clone()

-                alpha = p[0]
+                alpha = p[k][0]
+
+                if key.startswith("model.diffusion_model."):
+                    print(f"{key}: {alpha}")
+                # sk = key.split(".")
+                # block_key = ".".join(sk[2:4])
+                # if LORA_BLOCK_WEIGHTS.__contains__(block_key):
+                #     alpha *= LORA_BLOCK_WEIGHTS[block_key]

                 if len(v) == 4: #lora/locon
                     mat1 = v[0]
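
For context on the two hunks above: before this commit, each add_patches() call appended a single (strength, {key: patch}) tuple to self.patches and patch_model() read the shared strength from p[0]; afterwards each call appends a {key: (alpha, patch)} dict, so every key carries its own pre-scaled strength. A minimal sketch with placeholder values (key name, numbers, and the "<lora tensors>" string are illustrative only):

# Illustration only; tensor payloads are replaced by a placeholder string.
key = "model.diffusion_model.input_blocks.1.1.proj_in.weight"

# Old layout: one shared strength per add_patches() call.
old_entry = (0.8, {key: "<lora tensors>"})

# New layout: per-key (alpha, patch), where alpha = strength * optional block weight.
new_entry = {key: (0.8 * 0.5, "<lora tensors>")}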
@@ -342,12 +359,12 @@ class ModelPatcher:
         self.backup = {}

-def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
+def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, block_weights):
     key_map = model_lora_keys(model.model)
     key_map = model_lora_keys(clip.cond_stage_model, key_map)
     loaded = load_lora(lora_path, key_map)
     new_modelpatcher = model.clone()
-    k = new_modelpatcher.add_patches(loaded, strength_model)
+    k = new_modelpatcher.add_patches(loaded, strength_model, block_weights)
     new_clip = clip.clone()
     k1 = new_clip.add_patches(loaded, strength_clip)
     k = set(k)
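
A minimal sketch of how the block_key lookup added to add_patches resolves for a UNet key; the key name and the 0.8 / 0.5 values below are illustrative, not taken from the commit:

# Minimal sketch (not part of the commit) of the block_key derivation above.
k = "model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight"
block_weights = {"input_blocks.1": 0.5}
strength = 0.8

sk = k.split(".")
block_key = ".".join(sk[2:4])        # -> "input_blocks.1"

if block_key in block_weights:
    alpha = strength * block_weights[block_key]   # 0.8 * 0.5 = 0.4
else:
    alpha = strength                              # base strength only

print(block_key, alpha)              # input_blocks.1 0.4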


@@ -251,7 +251,83 @@ class LoraLoader:
     def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
         lora_path = folder_paths.get_full_path("loras", lora_name)
-        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, {})
         return (model_lora, clip_lora)

+class LoraLoaderBlockWeights:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "clip": ("CLIP", ),
+                              "lora_name": (folder_paths.get_filename_list("loras"), ),
+                              "strength_model": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                              "strength_clip": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                              "in0_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in1": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in2": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in3_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in4": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in5": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in6_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in7": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in8": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in9_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in10_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "in11_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "mid": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out0_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out1_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out2_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out3": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out4": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out5": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out6": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out7": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out8": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out9": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out10": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
+                              "out11": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01})
+                              }}
+    RETURN_TYPES = ("MODEL", "CLIP")
+    FUNCTION = "load_lora"
+
+    CATEGORY = "loaders"
+
+    def load_lora(self, model, clip, lora_name, strength_model, strength_clip,
+                  in0_locon, in1, in2, in3_locon, in4, in5, in6_locon, in7, in8, in9_locon, in10_locon, in11_locon,
+                  mid,
+                  out0_locon, out1_locon, out2_locon, out3, out4, out5, out6, out7, out8, out9, out10, out11):
+        lora_path = folder_paths.get_full_path("loras", lora_name)
+
+        block_weights = {
+            "input_blocks.0": in0_locon,
+            "input_blocks.1": in1,
+            "input_blocks.2": in2,
+            "input_blocks.3": in3_locon,
+            "input_blocks.4": in4,
+            "input_blocks.5": in5,
+            "input_blocks.6": in6_locon,
+            "input_blocks.7": in7,
+            "input_blocks.8": in8,
+            "input_blocks.9": in9_locon,
+            "input_blocks.10": in10_locon,
+            "input_blocks.11": in11_locon,
+            "middle_block.1": mid,
+            "output_blocks.0": out0_locon,
+            "output_blocks.1": out1_locon,
+            "output_blocks.2": out2_locon,
+            "output_blocks.3": out3,
+            "output_blocks.4": out4,
+            "output_blocks.5": out5,
+            "output_blocks.6": out6,
+            "output_blocks.7": out7,
+            "output_blocks.8": out8,
+            "output_blocks.9": out9,
+            "output_blocks.10": out10,
+            "output_blocks.11": out11
+        }
+
+        model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, block_weights)
+        return (model_lora, clip_lora)
+
 class VAELoader:
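
One practical note on the node above: every block slider defaults to 0.0, so a matched UNet block receives alpha = strength_model * 0.0 and its LoRA patch is effectively disabled until the slider is raised. A minimal sketch (values illustrative):

# Illustration of the 0.0 default: a matched block ends up with alpha == 0.0.
strength_model = 1.0
block_weights = {"input_blocks.4": 0.0}   # slider left at its default value

alpha = strength_model * block_weights["input_blocks.4"]
print(alpha)   # 0.0 -> the LoRA patch for this block has no effect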
@@ -1006,6 +1082,7 @@ NODE_CLASS_MAPPINGS = {
     "LatentFlip": LatentFlip,
     "LatentCrop": LatentCrop,
     "LoraLoader": LoraLoader,
+    "LoraLoaderBlockWeights": LoraLoaderBlockWeights,
     "CLIPLoader": CLIPLoader,
     "CLIPVisionEncode": CLIPVisionEncode,
     "StyleModelApply": StyleModelApply,