diff --git a/comfy/sd.py b/comfy/sd.py index e6cda5131..e016bea07 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -302,12 +302,14 @@ class ModelPatcher: t = model_sd[k] size += t.nelement() * t.element_size() self.size = size + self.model_keys = set(model_sd.keys()) return size def clone(self): n = ModelPatcher(self.model, self.size) n.patches = self.patches[:] n.model_options = copy.deepcopy(self.model_options) + n.model_keys = self.model_keys return n def set_model_tomesd(self, ratio): @@ -347,17 +349,25 @@ class ModelPatcher: def model_dtype(self): return self.model.get_dtype() - def add_patches(self, patches, strength=1.0): + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): p = {} - model_sd = self.model.state_dict() for k in patches: - if k in model_sd: + if k in self.model_keys: p[k] = patches[k] - self.patches += [(strength, p)] + self.patches += [(strength_patch, p, strength_model)] return p.keys() + def model_state_dict(self, filter_prefix=None): + sd = self.model.state_dict() + keys = list(sd.keys()) + if filter_prefix is not None: + for k in keys: + if not k.startswith(filter_prefix): + sd.pop(k) + return sd + def patch_model(self): - model_sd = self.model.state_dict() + model_sd = self.model_state_dict() for p in self.patches: for k in p[1]: v = p[1][k] @@ -371,8 +381,14 @@ class ModelPatcher: self.backup[key] = weight.clone() alpha = p[0] + strength_model = p[2] - if len(v) == 4: #lora/locon + if strength_model != 1.0: + weight *= strength_model + + if len(v) == 1: + weight += alpha * (v[0]).type(weight.dtype).to(weight.device) + elif len(v) == 4: #lora/locon mat1 = v[0] mat2 = v[1] if v[2] is not None: @@ -428,7 +444,7 @@ class ModelPatcher: weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype).to(weight.device) return self.model def unpatch_model(self): - model_sd = self.model.state_dict() + model_sd = self.model_state_dict() keys = list(self.backup.keys()) for k in keys: model_sd[k][:] = self.backup[k] diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py new file mode 100644 index 000000000..52b73f702 --- /dev/null +++ b/comfy_extras/nodes_model_merging.py @@ -0,0 +1,55 @@ + + +class ModelMergeSimple: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2, ratio): + m = model1.clone() + sd = model2.model_state_dict("diffusion_model.") + for k in sd: + m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) + return (m, ) + +class ModelMergeBlocks: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model1": ("MODEL",), + "model2": ("MODEL",), + "input": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "middle": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}) + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "merge" + + CATEGORY = "_for_testing/model_merging" + + def merge(self, model1, model2, **kwargs): + m = model1.clone() + sd = model2.model_state_dict("diffusion_model.") + default_ratio = next(iter(kwargs.values())) + + for k in sd: + ratio = default_ratio + k_unet = k[len("diffusion_model."):] + + for arg in kwargs: + if k_unet.startswith(arg): + ratio = kwargs[arg] + + m.add_patches({k: (sd[k], )}, 1.0 - ratio, ratio) + return (m, ) + +NODE_CLASS_MAPPINGS = { + "ModelMergeSimple": ModelMergeSimple, + "ModelMergeBlocks": ModelMergeBlocks +} diff --git a/nodes.py b/nodes.py index b3e7e6006..907c807cf 100644 --- a/nodes.py +++ b/nodes.py @@ -1583,4 +1583,5 @@ def init_custom_nodes(): load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_rebatch.py")) + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_model_merging.py")) load_custom_nodes()