diff --git a/README.md b/README.md index bf16006bf..5b6346a67 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs and CLIP models. - Embeddings/Textual inversion - [Loras (regular, locon and loha)](https://comfyanonymous.github.io/ComfyUI_examples/lora/) +- [Hypernetworks](https://comfyanonymous.github.io/ComfyUI_examples/hypernetworks/) - Loading full workflows (with seeds) from generated PNG files. - Saving/Loading workflows as Json files. - Nodes interface can be used to create complex workflows like one for [Hires fix](https://comfyanonymous.github.io/ComfyUI_examples/2_pass_txt2img/) or much more advanced ones. diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index c27d032a3..ce7180d91 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -163,13 +163,17 @@ class CrossAttentionBirchSan(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads query = self.to_q(x) context = default(context, x) key = self.to_k(context) - value = self.to_v(context) + if value is not None: + value = self.to_v(value) + else: + value = self.to_v(context) + del context, x query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) @@ -256,13 +260,17 @@ class CrossAttentionDoggettx(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads q_in = self.to_q(x) context = default(context, x) k_in = self.to_k(context) - v_in = self.to_v(context) + if value is not None: + v_in = self.to_v(value) + del value + else: + v_in = self.to_v(context) del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) @@ -350,13 +358,17 @@ class CrossAttention(nn.Module): nn.Dropout(dropout) ) - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): h = self.heads q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) @@ -402,11 +414,15 @@ class MemoryEfficientCrossAttention(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) b, _, _ = q.shape q, k, v = map( @@ -447,11 +463,15 @@ class CrossAttentionPytorch(nn.Module): self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) self.attention_op: Optional[Any] = None - def forward(self, x, context=None, mask=None): + def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) context = default(context, x) k = self.to_k(context) - v = self.to_v(context) + if value is not None: + v = self.to_v(value) + del value + else: + v = self.to_v(context) b, _, _ = q.shape q, k, v = map( @@ -512,11 +532,25 @@ class BasicTransformerBlock(nn.Module): transformer_patches = {} n = self.norm1(x) + if self.disable_self_attn: + context_attn1 = context + else: + context_attn1 = None + value_attn1 = None + + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + if context_attn1 is None: + context_attn1 = n + value_attn1 = context_attn1 + for p in patch: + n, context_attn1, value_attn1 = p(current_index, n, context_attn1, value_attn1) + if "tomesd" in transformer_options: m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"]) - n = u(self.attn1(m(n), context=context if self.disable_self_attn else None)) + n = u(self.attn1(m(n), context=context_attn1, value=value_attn1)) else: - n = self.attn1(n, context=context if self.disable_self_attn else None) + n = self.attn1(n, context=context_attn1, value=value_attn1) x += n if "middle_patch" in transformer_patches: @@ -525,7 +559,16 @@ class BasicTransformerBlock(nn.Module): x = p(current_index, x) n = self.norm2(x) - n = self.attn2(n, context=context) + + context_attn2 = context + value_attn2 = None + if "attn2_patch" in transformer_patches: + patch = transformer_patches["attn2_patch"] + value_attn2 = context_attn2 + for p in patch: + n, context_attn2, value_attn2 = p(current_index, n, context_attn2, value_attn2) + + n = self.attn2(n, context=context_attn2, value=value_attn2) x += n x = self.ff(self.norm3(x)) + x diff --git a/comfy/model_management.py b/comfy/model_management.py index a0d1313d2..6e3a03530 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -133,6 +133,7 @@ def unload_model(): #never unload models from GPU on high vram if vram_state != VRAMState.HIGH_VRAM: current_loaded_model.model.cpu() + current_loaded_model.model_patches_to("cpu") current_loaded_model.unpatch_model() current_loaded_model = None @@ -156,6 +157,8 @@ def load_model_gpu(model): except Exception as e: model.unpatch_model() raise e + + model.model_patches_to(get_torch_device()) current_loaded_model = model if vram_state == VRAMState.CPU: pass diff --git a/comfy/samplers.py b/comfy/samplers.py index 15527224e..b860f25f1 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -197,7 +197,15 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con transformer_options = model_options['transformer_options'].copy() if patches is not None: - transformer_options["patches"] = patches + if "patches" in transformer_options: + cur_patches = transformer_options["patches"].copy() + for p in patches: + if p in cur_patches: + cur_patches[p] = cur_patches[p] + patches[p] + else: + cur_patches[p] = patches[p] + else: + transformer_options["patches"] = patches c['transformer_options'] = transformer_options diff --git a/comfy/sd.py b/comfy/sd.py index 211acd70e..92dbb931d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -254,6 +254,29 @@ class ModelPatcher: def set_model_sampler_cfg_function(self, sampler_cfg_function): self.model_options["sampler_cfg_function"] = sampler_cfg_function + + def set_model_patch(self, patch, name): + to = self.model_options["transformer_options"] + if "patches" not in to: + to["patches"] = {} + to["patches"][name] = to["patches"].get(name, []) + [patch] + + def set_model_attn1_patch(self, patch): + self.set_model_patch(patch, "attn1_patch") + + def set_model_attn2_patch(self, patch): + self.set_model_patch(patch, "attn2_patch") + + def model_patches_to(self, device): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], "to"): + patch_list[i] = patch_list[i].to(device) + def model_dtype(self): return self.model.diffusion_model.dtype diff --git a/comfy/utils.py b/comfy/utils.py index 0380b91dd..68f93403c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,11 +1,14 @@ import torch -def load_torch_file(ckpt): +def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: - pl_sd = torch.load(ckpt, map_location="cpu") + if safe_load: + pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) + else: + pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: print(f"Global Step: {pl_sd['global_step']}") if "state_dict" in pl_sd: diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py new file mode 100644 index 000000000..0c7250e43 --- /dev/null +++ b/comfy_extras/nodes_hypernetwork.py @@ -0,0 +1,109 @@ +import comfy.utils +import folder_paths +import torch + +def load_hypernetwork_patch(path, strength): + sd = comfy.utils.load_torch_file(path, safe_load=True) + activation_func = sd.get('activation_func', 'linear') + is_layer_norm = sd.get('is_layer_norm', False) + use_dropout = sd.get('use_dropout', False) + activate_output = sd.get('activate_output', False) + last_layer_dropout = sd.get('last_layer_dropout', False) + + valid_activation = { + "linear": torch.nn.Identity, + "relu": torch.nn.ReLU, + "leakyrelu": torch.nn.LeakyReLU, + "elu": torch.nn.ELU, + "swish": torch.nn.Hardswish, + "tanh": torch.nn.Tanh, + "sigmoid": torch.nn.Sigmoid, + } + + if activation_func not in valid_activation: + print("Unsupported Hypernetwork format, if you report it I might implement it.", path, " ", activation_func, is_layer_norm, use_dropout, activate_output, last_layer_dropout) + return None + + out = {} + + for d in sd: + try: + dim = int(d) + except: + continue + + output = [] + for index in [0, 1]: + attn_weights = sd[dim][index] + keys = attn_weights.keys() + + linears = filter(lambda a: a.endswith(".weight"), keys) + linears = list(map(lambda a: a[:-len(".weight")], linears)) + layers = [] + + for i in range(len(linears)): + lin_name = linears[i] + last_layer = (i == (len(linears) - 1)) + penultimate_layer = (i == (len(linears) - 2)) + + lin_weight = attn_weights['{}.weight'.format(lin_name)] + lin_bias = attn_weights['{}.bias'.format(lin_name)] + layer = torch.nn.Linear(lin_weight.shape[1], lin_weight.shape[0]) + layer.load_state_dict({"weight": lin_weight, "bias": lin_bias}) + layers.append(layer) + if activation_func != "linear": + if (not last_layer) or (activate_output): + layers.append(valid_activation[activation_func]()) + if is_layer_norm: + layers.append(torch.nn.LayerNorm(lin_weight.shape[0])) + if use_dropout: + if (not last_layer) and (not penultimate_layer or last_layer_dropout): + layers.append(torch.nn.Dropout(p=0.3)) + + output.append(torch.nn.Sequential(*layers)) + out[dim] = torch.nn.ModuleList(output) + + class hypernetwork_patch: + def __init__(self, hypernet, strength): + self.hypernet = hypernet + self.strength = strength + def __call__(self, current_index, q, k, v): + dim = k.shape[-1] + if dim in self.hypernet: + hn = self.hypernet[dim] + k = k + hn[0](k) * self.strength + v = v + hn[1](v) * self.strength + + return q, k, v + + def to(self, device): + for d in self.hypernet.keys(): + self.hypernet[d] = self.hypernet[d].to(device) + return self + + return hypernetwork_patch(out, strength) + +class HypernetworkLoader: + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), + "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_hypernetwork" + + CATEGORY = "loaders" + + def load_hypernetwork(self, model, hypernetwork_name, strength): + hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name) + model_hypernetwork = model.clone() + patch = load_hypernetwork_patch(hypernetwork_path, strength) + if patch is not None: + model_hypernetwork.set_model_attn1_patch(patch) + model_hypernetwork.set_model_attn2_patch(patch) + return (model_hypernetwork,) + +NODE_CLASS_MAPPINGS = { + "HypernetworkLoader": HypernetworkLoader +} diff --git a/execution.py b/execution.py index b062deeb1..31a208e78 100644 --- a/execution.py +++ b/execution.py @@ -11,7 +11,6 @@ import torch import nodes import comfy.model_management -import folder_paths def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_data={}): valid_inputs = class_def.INPUT_TYPES() @@ -41,15 +40,13 @@ def get_input_data(inputs, class_def, unique_id, outputs={}, prompt={}, extra_da input_data_all[x] = unique_id return input_data_all -def recursive_execute(server, prompt, outputs, current_item, extra_data={}): +def recursive_execute(server, prompt, outputs, current_item, extra_data, executed): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return [] - - executed = [] + return for x in inputs: input_data = inputs[x] @@ -58,7 +55,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - executed += recursive_execute(server, prompt, outputs, input_unique_id, extra_data) + recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed) input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: @@ -73,7 +70,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data={}): server.send_sync("executed", { "node": unique_id, "output": outputs[unique_id]["ui"] }, server.client_id) if "result" in outputs[unique_id]: outputs[unique_id] = outputs[unique_id]["result"] - return executed + [unique_id] + executed.add(unique_id) def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item @@ -159,7 +156,7 @@ class PromptExecutor: recursive_output_delete_if_changed(prompt, self.old_prompt, self.outputs, x) current_outputs = set(self.outputs.keys()) - executed = [] + executed = set() try: to_execute = [] for x in prompt: @@ -182,12 +179,12 @@ class PromptExecutor: except: valid = False if valid: - executed += recursive_execute(self.server, prompt, self.outputs, x, extra_data) + recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed) except Exception as e: print(traceback.format_exc()) to_delete = [] for o in self.outputs: - if o not in current_outputs: + if (o not in current_outputs) and (o not in executed): to_delete += [o] if o in self.old_prompt: d = self.old_prompt.pop(o) @@ -195,11 +192,9 @@ class PromptExecutor: for o in to_delete: d = self.outputs.pop(o) del d - else: - executed = set(executed) + finally: for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) - finally: self.server.last_node_id = None if self.server.client_id is not None: self.server.send_sync("executing", { "node": None }, self.server.client_id) @@ -250,14 +245,15 @@ def validate_inputs(prompt, item): if "max" in info[1] and val > info[1]["max"]: return (False, "Value bigger than max. {}, {}".format(class_type, x)) - if isinstance(type_input, list): - is_annotated_path = val.endswith("[temp]") or val.endswith("[input]") or val.endswith("[output]") - if is_annotated_path: - if not folder_paths.exists_annotated_filepath(val): - return (False, "Invalid file path. {}, {}: {}".format(class_type, x, val)) - - elif val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + if hasattr(obj_class, "VALIDATE_INPUTS"): + input_data_all = get_input_data(inputs, obj_class, unique_id) + ret = obj_class.VALIDATE_INPUTS(**input_data_all) + if ret != True: + return (False, "{}, {}".format(class_type, ret)) + else: + if isinstance(type_input, list): + if val not in type_input: + return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) return (True, "") def validate_prompt(prompt): @@ -279,7 +275,8 @@ def validate_prompt(prompt): m = validate_inputs(prompt, o) valid = m[0] reason = m[1] - except: + except Exception as e: + print(traceback.format_exc()) valid = False reason = "Parsing error" diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index ac1ffe9d2..fa5418a68 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -13,6 +13,7 @@ a111: models/ESRGAN models/SwinIR embeddings: embeddings + hypernetworks: models/hypernetworks controlnet: models/ControlNet #other_ui: diff --git a/folder_paths.py b/folder_paths.py index 8d24e3ed6..e5b89492c 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -32,6 +32,7 @@ folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_m folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], []) +folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") @@ -70,7 +71,7 @@ def get_directory_by_type(type_name): # determine base_dir rely on annotation if name is 'filename.ext [annotation]' format # otherwise use default_path as base_dir -def touch_annotated_filepath(name): +def annotated_filepath(name): if name.endswith("[output]"): base_dir = get_output_directory() name = name[:-9] @@ -87,7 +88,7 @@ def touch_annotated_filepath(name): def get_annotated_filepath(name, default_dir=None): - name, base_dir = touch_annotated_filepath(name) + name, base_dir = annotated_filepath(name) if base_dir is None: if default_dir is not None: @@ -99,7 +100,7 @@ def get_annotated_filepath(name, default_dir=None): def exists_annotated_filepath(name): - name, base_dir = touch_annotated_filepath(name) + name, base_dir = annotated_filepath(name) if base_dir is None: base_dir = get_input_directory() # fallback path diff --git a/models/hypernetworks/put_hypernetworks_here b/models/hypernetworks/put_hypernetworks_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 173eb32c9..d1133d1d8 100644 --- a/nodes.py +++ b/nodes.py @@ -974,8 +974,7 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" def load_image(self, image): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -989,20 +988,27 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod + def VALIDATE_INPUTS(s, image): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + + return True + class LoadImageMask: + _color_channels = ["alpha", "red", "green", "blue"] @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() return {"required": {"image": (sorted(os.listdir(input_dir)), ), - "channel": (["alpha", "red", "green", "blue"], ),} + "channel": (s._color_channels, ),} } CATEGORY = "mask" @@ -1010,8 +1016,7 @@ class LoadImageMask: RETURN_TYPES = ("MASK",) FUNCTION = "load_image" def load_image(self, image, channel): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) i = Image.open(image_path) if i.getbands() != ("R", "G", "B", "A"): i = i.convert("RGBA") @@ -1028,13 +1033,22 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): - input_dir = folder_paths.get_input_directory() - image_path = folder_paths.get_annotated_filepath(image, input_dir) + image_path = folder_paths.get_annotated_filepath(image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) return m.digest().hex() + @classmethod + def VALIDATE_INPUTS(s, image, channel): + if not folder_paths.exists_annotated_filepath(image): + return "Invalid image file: {}".format(image) + + if channel not in s._color_channels: + return "Invalid color channel: {}".format(channel) + + return True + class ImageScale: upscale_methods = ["nearest-exact", "bilinear", "area"] crop_methods = ["disabled", "center"] @@ -1268,6 +1282,7 @@ def load_custom_nodes(): def init_custom_nodes(): load_custom_nodes() + load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_hypernetwork.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_upscale_model.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_post_processing.py")) load_custom_node(os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy_extras"), "nodes_mask.py")) diff --git a/web/scripts/api.js b/web/scripts/api.js index 2b90c2abc..d29faa5ba 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -35,7 +35,7 @@ class ComfyApi extends EventTarget { } let opened = false; - let existingSession = sessionStorage["Comfy.SessionId"] || ""; + let existingSession = window.name; if (existingSession) { existingSession = "?clientId=" + existingSession; } @@ -75,7 +75,7 @@ class ComfyApi extends EventTarget { case "status": if (msg.data.sid) { this.clientId = msg.data.sid; - sessionStorage["Comfy.SessionId"] = this.clientId; + window.name = this.clientId; } this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); break; diff --git a/web/scripts/app.js b/web/scripts/app.js index b3e88d46f..3bda52cdd 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -145,7 +145,7 @@ export class ComfyApp { if(this.widgets) { widgets = this.widgets.map(({ type, name, value }) => ({ type, name, value })); } - + let img = new Image(); var imgs = undefined; if(this.imgs != undefined) { @@ -172,7 +172,7 @@ export class ComfyApp { ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { const prop = Object.values(this.widgets).find(obj => obj.type === type && obj.name === name); if (prop) { - prop.value = value; + prop.callback(value); } }); }