mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 13:20:19 +08:00
feat: LoraLoaderBlockWeights
Lora Block Weights
This commit is contained in:
parent
31dd6c0531
commit
bd0a8163eb
33
comfy/sd.py
33
comfy/sd.py
@ -280,21 +280,31 @@ class ModelPatcher:
|
|||||||
n.patches = self.patches[:]
|
n.patches = self.patches[:]
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def add_patches(self, patches, strength=1.0):
|
def add_patches(self, patches, strength=1.0, block_weights={}):
|
||||||
p = {}
|
p = {}
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model.state_dict()
|
||||||
for k in patches:
|
for k in patches:
|
||||||
if k in model_sd:
|
if k in model_sd:
|
||||||
p[k] = patches[k]
|
sk = k.split(".")
|
||||||
self.patches += [(strength, p)]
|
block_key = ".".join(sk[2:4])
|
||||||
|
if block_weights.__contains__(block_key):
|
||||||
|
# apply block weights
|
||||||
|
p[k] = (strength * block_weights[block_key], patches[k])
|
||||||
|
else:
|
||||||
|
# apply only base strength
|
||||||
|
p[k] = (strength, patches[k])
|
||||||
|
|
||||||
|
self.patches += [p]
|
||||||
return p.keys()
|
return p.keys()
|
||||||
|
|
||||||
def patch_model(self):
|
def patch_model(self):
|
||||||
model_sd = self.model.state_dict()
|
model_sd = self.model.state_dict()
|
||||||
for p in self.patches:
|
for p in self.patches:
|
||||||
for k in p[1]:
|
|
||||||
v = p[1][k]
|
for k in p:
|
||||||
|
v = p[k][1]
|
||||||
key = k
|
key = k
|
||||||
|
|
||||||
if key not in model_sd:
|
if key not in model_sd:
|
||||||
print("could not patch. key doesn't exist in model:", k)
|
print("could not patch. key doesn't exist in model:", k)
|
||||||
continue
|
continue
|
||||||
@ -303,7 +313,14 @@ class ModelPatcher:
|
|||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = weight.clone()
|
self.backup[key] = weight.clone()
|
||||||
|
|
||||||
alpha = p[0]
|
alpha = p[k][0]
|
||||||
|
|
||||||
|
if key.startswith("model.diffusion_model."):
|
||||||
|
print(f"{key}: {alpha}")
|
||||||
|
# sk = key.split(".")
|
||||||
|
# block_key = ".".join(sk[2:4])
|
||||||
|
# if LORA_BLOCK_WEIGHTS.__contains__(block_key):
|
||||||
|
# alpha *= LORA_BLOCK_WEIGHTS[block_key]
|
||||||
|
|
||||||
if len(v) == 4: #lora/locon
|
if len(v) == 4: #lora/locon
|
||||||
mat1 = v[0]
|
mat1 = v[0]
|
||||||
@ -342,12 +359,12 @@ class ModelPatcher:
|
|||||||
|
|
||||||
self.backup = {}
|
self.backup = {}
|
||||||
|
|
||||||
def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip):
|
def load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, block_weights):
|
||||||
key_map = model_lora_keys(model.model)
|
key_map = model_lora_keys(model.model)
|
||||||
key_map = model_lora_keys(clip.cond_stage_model, key_map)
|
key_map = model_lora_keys(clip.cond_stage_model, key_map)
|
||||||
loaded = load_lora(lora_path, key_map)
|
loaded = load_lora(lora_path, key_map)
|
||||||
new_modelpatcher = model.clone()
|
new_modelpatcher = model.clone()
|
||||||
k = new_modelpatcher.add_patches(loaded, strength_model)
|
k = new_modelpatcher.add_patches(loaded, strength_model, block_weights)
|
||||||
new_clip = clip.clone()
|
new_clip = clip.clone()
|
||||||
k1 = new_clip.add_patches(loaded, strength_clip)
|
k1 = new_clip.add_patches(loaded, strength_clip)
|
||||||
k = set(k)
|
k = set(k)
|
||||||
|
|||||||
79
nodes.py
79
nodes.py
@ -251,7 +251,83 @@ class LoraLoader:
|
|||||||
|
|
||||||
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
|
||||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||||
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip)
|
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, {})
|
||||||
|
return (model_lora, clip_lora)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraLoaderBlockWeights:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"clip": ("CLIP", ),
|
||||||
|
"lora_name": (folder_paths.get_filename_list("loras"), ),
|
||||||
|
"strength_model": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"strength_clip": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||||
|
"in0_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in1": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in2": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in3_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in4": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in5": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in6_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in7": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in8": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in9_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in10_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"in11_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"mid": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out0_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out1_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out2_locon": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out3": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out4": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out5": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out6": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out7": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out8": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out9": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out10": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
|
"out11": ("FLOAT", {"default": 0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||||
|
}}
|
||||||
|
RETURN_TYPES = ("MODEL", "CLIP")
|
||||||
|
FUNCTION = "load_lora"
|
||||||
|
|
||||||
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
|
def load_lora(self, model, clip, lora_name, strength_model, strength_clip,
|
||||||
|
in0_locon, in1, in2, in3_locon, in4, in5, in6_locon, in7, in8, in9_locon, in10_locon, in11_locon,
|
||||||
|
mid,
|
||||||
|
out0_locon, out1_locon, out2_locon, out3, out4, out5, out6, out7, out8, out9, out10, out11):
|
||||||
|
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||||
|
block_weights = {
|
||||||
|
"input_blocks.0": in0_locon,
|
||||||
|
"input_blocks.1": in1,
|
||||||
|
"input_blocks.2": in2,
|
||||||
|
"input_blocks.3": in3_locon,
|
||||||
|
"input_blocks.4": in4,
|
||||||
|
"input_blocks.5": in5,
|
||||||
|
"input_blocks.6": in6_locon,
|
||||||
|
"input_blocks.7": in7,
|
||||||
|
"input_blocks.8": in8,
|
||||||
|
"input_blocks.9": in9_locon,
|
||||||
|
"input_blocks.10": in10_locon,
|
||||||
|
"input_blocks.11": in11_locon,
|
||||||
|
"middle_block.1": mid,
|
||||||
|
"output_blocks.0": out0_locon,
|
||||||
|
"output_blocks.1": out1_locon,
|
||||||
|
"output_blocks.2": out2_locon,
|
||||||
|
"output_blocks.3": out3,
|
||||||
|
"output_blocks.4": out4,
|
||||||
|
"output_blocks.5": out5,
|
||||||
|
"output_blocks.6": out6,
|
||||||
|
"output_blocks.7": out7,
|
||||||
|
"output_blocks.8": out8,
|
||||||
|
"output_blocks.9": out9,
|
||||||
|
"output_blocks.10": out10,
|
||||||
|
"output_blocks.11": out11
|
||||||
|
}
|
||||||
|
|
||||||
|
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip, block_weights)
|
||||||
return (model_lora, clip_lora)
|
return (model_lora, clip_lora)
|
||||||
|
|
||||||
class VAELoader:
|
class VAELoader:
|
||||||
@ -1006,6 +1082,7 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"LatentFlip": LatentFlip,
|
"LatentFlip": LatentFlip,
|
||||||
"LatentCrop": LatentCrop,
|
"LatentCrop": LatentCrop,
|
||||||
"LoraLoader": LoraLoader,
|
"LoraLoader": LoraLoader,
|
||||||
|
"LoraLoaderBlockWeights": LoraLoaderBlockWeights,
|
||||||
"CLIPLoader": CLIPLoader,
|
"CLIPLoader": CLIPLoader,
|
||||||
"CLIPVisionEncode": CLIPVisionEncode,
|
"CLIPVisionEncode": CLIPVisionEncode,
|
||||||
"StyleModelApply": StyleModelApply,
|
"StyleModelApply": StyleModelApply,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user