From 9b1396e93a19748dd4c4bb35637638bb0f91b5f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 24 May 2023 14:01:11 -0400 Subject: [PATCH 01/82] Fix issue importing other ui prompts. --- web/scripts/pnginfo.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 8ddb7a1c5..977b5ac2f 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -69,6 +69,7 @@ export async function importA1111(graph, parameters) { const embeddings = await api.getEmbeddings(); const opts = parameters .substr(p) + .split("\n")[1] .split(",") .reduce((p, n) => { const s = n.split(":"); From 8b4b0c3188110e1faa8865570637172ab4b60ba1 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Thu, 25 May 2023 19:23:47 +0200 Subject: [PATCH 02/82] vecorized bislerp --- comfy/utils.py | 117 +++++++++++++++++++++++++++---------------------- 1 file changed, 64 insertions(+), 53 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 300eda6aa..cc0e5069a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,6 @@ import torch import math +import einops def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -46,71 +47,81 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd -#slow and inefficient, should be optimized def bislerp(samples, width, height): - shape = list(samples.shape) - width_scale = (shape[3]) / (width ) - height_scale = (shape[2]) / (height ) + def slerp(b1, b2, r): + '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' + + c = b1.shape[-1] - shape[3] = width - shape[2] = height - out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) + #norms + b1_norms = torch.norm(b1, dim=-1, keepdim=True) + b2_norms = torch.norm(b2, dim=-1, keepdim=True) - def algorithm(in1, in2, t): - dims = in1.shape - val = t + #normalize + b1_normalized = b1 / b1_norms + b2_normalized = b2 / b2_norms - #flatten to batches - low = in1.reshape(dims[0], -1) - high = in2.reshape(dims[0], -1) + #zero when norms are zero + b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0 + b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0 - low_weight = torch.norm(low, dim=1, keepdim=True) - low_weight[low_weight == 0] = 0.0000000001 - low_norm = low/low_weight - high_weight = torch.norm(high, dim=1, keepdim=True) - high_weight[high_weight == 0] = 0.0000000001 - high_norm = high/high_weight - - dot_prod = (low_norm*high_norm).sum(1) - dot_prod[dot_prod > 0.9995] = 0.9995 - dot_prod[dot_prod < -0.9995] = -0.9995 - omega = torch.acos(dot_prod) + #slerp + dot = (b1_normalized*b2_normalized).sum(1) + omega = torch.acos(dot) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm - res *= (low_weight * (1.0-val) + high_weight * val) - return res.reshape(dims) - for x_dest in range(shape[3]): - for y_dest in range(shape[2]): - y = (y_dest + 0.5) * height_scale - 0.5 - x = (x_dest + 0.5) * width_scale - 0.5 + #technically not mathematically correct, but more pleasing? 
+ res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized + res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) - x1 = max(math.floor(x), 0) - x2 = min(x1 + 1, samples.shape[3] - 1) - wx = x - math.floor(x) + #edge cases for same or polar opposites + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] + return res + + def generate_bilinear_data(length_old, length_new): + coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") + ratios = coords_1 - coords_1.floor() + coords_1 = coords_1.to(torch.int64) + + coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1 + coords_2[:,:,:,-1] -= 1 + coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") + coords_2 = coords_2.to(torch.int64) + return ratios, coords_1, coords_2 + + n,c,h,w = samples.shape + h_new, w_new = (height, width) + + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) - y1 = max(math.floor(y), 0) - y2 = min(y1 + 1, samples.shape[2] - 1) - wy = y - math.floor(y) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w)) - in1 = samples[:,:,y1,x1] - in2 = samples[:,:,y1,x2] - in3 = samples[:,:,y2,x1] - in4 = samples[:,:,y2,x2] + pass_1 = einops.rearrange(samples.gather(-2,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(samples.gather(-2,coords_2), 'n c h w -> (n h w) c') + ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') - if (x1 == x2) and (y1 == y2): - out_value = in1 - elif (x1 == x2): - out_value = algorithm(in1, in3, wy) - elif (y1 == y2): - out_value = algorithm(in1, in2, wx) - else: - o1 = algorithm(in1, in2, wx) - o2 = algorithm(in3, in4, wx) - out_value = algorithm(o1, o2, wy) + result = slerp(pass_1, pass_2, ratios) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w) - out1[:,:,y_dest,x_dest] = out_value - return out1 + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + + coords_1 = coords_1.expand((n, c, h_new, -1)) + coords_2 = coords_2.expand((n, c, h_new, -1)) + ratios = ratios.expand((n, 1, h_new, -1)) + + pass_1 = einops.rearrange(result.gather(-1,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(result.gather(-1,coords_2), 'n c h w -> (n h w) c') + ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + + result = slerp(pass_1, pass_2, ratios) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w_new) + return result def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": From e1278fa925cf59350bae76dc3d0c59a0e9564789 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 13:30:59 -0400 Subject: [PATCH 03/82] Support old pytorch versions that don't have weights_only. 
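This patch's feature check can be sketched standalone as follows; the helper below is illustrative (the name load_checkpoint is invented for the example) and simply mirrors the probe of torch.load's signature that the diff adds:

import torch

def load_checkpoint(path, safe_load=True):
    # Older PyTorch releases do not accept the weights_only argument,
    # so probe torch.load's signature before relying on it.
    if safe_load and 'weights_only' not in torch.load.__code__.co_varnames:
        print("Warning: torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
        safe_load = False
    if safe_load:
        return torch.load(path, map_location="cpu", weights_only=True)
    return torch.load(path, map_location="cpu")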
--- comfy/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 300eda6aa..d58320b4a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -6,6 +6,10 @@ def load_torch_file(ckpt, safe_load=False): import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: + if safe_load: + if not 'weights_only' in torch.load.__code__.co_varnames: + print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") + safe_load = False if safe_load: pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) else: From 87ab25fac77ff1d558fea3c02733a463cb1fa013 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 18:31:27 -0400 Subject: [PATCH 04/82] Do operations in same order as the one it replaces. --- comfy/utils.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 33c1c3dd7..f139fbb27 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -98,29 +98,27 @@ def bislerp(samples, width, height): n,c,h,w = samples.shape h_new, w_new = (height, width) - #linear h - ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + coords_1 = coords_1.expand((n, c, h, -1)) + coords_2 = coords_2.expand((n, c, h, -1)) + ratios = ratios.expand((n, 1, h, -1)) - coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w)) - coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w)) - ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w)) - - pass_1 = einops.rearrange(samples.gather(-2,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(samples.gather(-2,coords_2), 'n c h w -> (n h w) c') + pass_1 = einops.rearrange(samples.gather(-1,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(samples.gather(-1,coords_2), 'n c h w -> (n h w) c') ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h, w=w_new) - #linear w - ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) - coords_1 = coords_1.expand((n, c, h_new, -1)) - coords_2 = coords_2.expand((n, c, h_new, -1)) - ratios = ratios.expand((n, 1, h_new, -1)) - - pass_1 = einops.rearrange(result.gather(-1,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(result.gather(-1,coords_2), 'n c h w -> (n h w) c') + pass_1 = einops.rearrange(result.gather(-2,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(result.gather(-2,coords_2), 'n c h w -> (n h w) c') ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') result = slerp(pass_1, pass_2, ratios) From eb4bd7711acec9a2a2d4f1d4dcc1d32e1236c976 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 18:42:56 -0400 Subject: [PATCH 05/82] Remove einops. 
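The einops.rearrange calls removed by this patch have direct torch equivalents built from movedim and reshape, which is what the diff substitutes. A minimal illustrative check of that equivalence (not repository code):

import torch

x = torch.randn(2, 4, 8, 8)                        # n c h w
n, c, h, w = x.shape
# einops.rearrange(x, 'n c h w -> (n h w) c') becomes:
flat = x.movedim(1, -1).reshape(-1, c)
# einops.rearrange(flat, '(n h w) c -> n c h w', n=n, h=h, w=w) becomes:
restored = flat.reshape(n, h, w, c).movedim(-1, 1)
assert torch.equal(x, restored)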
--- comfy/utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index f139fbb27..5ed9aaa02 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,6 +1,5 @@ import torch import math -import einops def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -104,12 +103,12 @@ def bislerp(samples, width, height): coords_2 = coords_2.expand((n, c, h, -1)) ratios = ratios.expand((n, 1, h, -1)) - pass_1 = einops.rearrange(samples.gather(-1,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(samples.gather(-1,coords_2), 'n c h w -> (n h w) c') - ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h, w=w_new) + result = result.reshape(n, h, w_new, c).movedim(-1, 1) #linear h ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) @@ -117,12 +116,12 @@ def bislerp(samples, width, height): coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) - pass_1 = einops.rearrange(result.gather(-2,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(result.gather(-2,coords_2), 'n c h w -> (n h w) c') - ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w_new) + result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) return result def common_upscale(samples, width, height, upscale_method, crop): From 4d1ed829d9a934d9a303a725e325f90934854ac8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 26 May 2023 19:33:30 -0500 Subject: [PATCH 06/82] Don't load some model types if weight is zero --- nodes.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/nodes.py b/nodes.py index f0a93ebd5..68010f040 100644 --- a/nodes.py +++ b/nodes.py @@ -426,6 +426,9 @@ class LoraLoader: CATEGORY = "loaders" def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + lora_path = folder_paths.get_full_path("loras", lora_name) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) @@ -507,6 +510,9 @@ class ControlNetApply: CATEGORY = "conditioning" def apply_controlnet(self, conditioning, control_net, image, strength): + if strength == 0: + return (conditioning, ) + c = [] control_hint = image.movedim(-1,1) for t in conditioning: @@ -613,6 +619,9 @@ class unCLIPConditioning: CATEGORY = "conditioning" def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): + if strength == 0: + return (conditioning, ) + c = [] for t in conditioning: o = t[1].copy() From 679bd2845af8e22b2802cf326b99b40a26ba7811 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 26 May 2023 21:46:11 -0400 Subject: [PATCH 07/82] Safetensors isn't optional anymore. 
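With safetensors now a hard dependency, the extension sets become plain literals. A hedged sketch of how such a set is typically used to filter model filenames; the helper name here is assumed for illustration and may not match the repository's exact implementation:

import os

supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])

def filter_files_extensions(files, extensions):
    # Keep only filenames whose extension is in the supported set.
    return sorted(f for f in files if os.path.splitext(f)[-1].lower() in extensions)

print(filter_files_extensions(["model.safetensors", "notes.txt", "old.ckpt"], supported_pt_extensions))
# ['model.safetensors', 'old.ckpt']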
--- folder_paths.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 28f117824..20b461c94 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,14 +1,7 @@ import os -supported_ckpt_extensions = set(['.ckpt', '.pth']) -supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth']) -try: - import safetensors.torch - supported_ckpt_extensions.add('.safetensors') - supported_pt_extensions.add('.safetensors') -except: - print("Could not import safetensors, safetensors support disabled.") - +supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) +supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) folder_names_and_paths = {} From 73e85fb3f4b104053fb1ac5d0aea456e373ea8c8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:00:47 -0500 Subject: [PATCH 08/82] Improve error output for failed nodes --- execution.py | 237 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 204 insertions(+), 33 deletions(-) diff --git a/execution.py b/execution.py index 25f2fcacd..691beb102 100644 --- a/execution.py +++ b/execution.py @@ -297,24 +297,80 @@ def validate_inputs(prompt, item, validated): class_inputs = obj_class.INPUT_TYPES() required_inputs = class_inputs['required'] + + errors = [] + valid = True + for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } + } + errors.append(error) + continue + val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) + error = { + "type": "bad_linked_input", + "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val + } + } + errors.append(error) + continue + o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. 
{}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) - r = validate_inputs(prompt, o_id, validated) - if r[0] == False: - validated[o_id] = r - return r + received_type = r[val[1]] + details = f"{x}, {received_type} != {type_input}" + error = { + "type": "return_type_mismatch", + "message": "Return type mismatch between linked nodes", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_type": received_type + } + } + errors.append(error) + continue + try: + r = validate_inputs(prompt, o_id, validated) + if r[0] is False: + # `r` will be set in `validated[o_id]` already + valid = False + continue + except Exception as ex: + typ, _, tb = sys.exc_info() + valid = False + error_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "error_type": error_type, + "traceback": traceback.format_tb(tb) + } + }] + validated[o_id] = (False, reasons, o_id) + continue else: if type_input == "INT": val = int(val) @@ -328,26 +384,97 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if "max" in info[1] and val > info[1]["max"]: - return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id) + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) #ret = obj_class.VALIDATE_INPUTS(**input_data_all) ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for r in ret: - if r != True: - return (False, "{}, {}".format(class_type, r), unique_id) + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f": {str(r)}" + else: + details += "." + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) + input_config = info + list_info = "" + + # Don't send back gigantic lists like if they're lots of + # scanned model filepaths + if len(type_input) > 20: + list_info = f"(list of length {len(type_input)})" + input_config = None + else: + list_info = str(type_input) + + error = { + "type": "value_not_in_list", + "message": "Value not in list", + "details": f"{x}: '{val}' not in {list_info}", + "extra_info": { + "input_name": x, + "input_config": input_config, + "received_value": val, + } + } + errors.append(error) + continue + + if len(errors) > 0 or valid is not True: + ret = (False, errors, unique_id) + else: + ret = (True, [], unique_id) - ret = (True, "", unique_id) validated[unique_id] = ret return ret +def full_type_name(klass): + module = klass.__module__ + if module == 'builtins': + return klass.__qualname__ + return module + '.' + klass.__qualname__ + def validate_prompt(prompt): outputs = set() for x in prompt: @@ -356,7 +483,13 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs", [], []) + error = { + "type": "prompt_no_outputs", + "message": "Prompt has no outputs", + "details": "", + "extra_info": {} + } + return (False, error, [], []) good_outputs = set() errors = [] @@ -364,34 +497,72 @@ def validate_prompt(prompt): validated = {} for o in outputs: valid = False - reason = "" + reasons = [] try: m = validate_inputs(prompt, o, validated) valid = m[0] - reason = m[1] - node_id = m[2] - except Exception as e: - print(traceback.format_exc()) + reasons = m[1] + except Exception as ex: + typ, _, tb = sys.exc_info() valid = False - reason = "Parsing error" - node_id = None + error_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "error_type": error_type, + "traceback": traceback.format_tb(tb) + } + }] + validated[o] = (False, reasons, o) - if valid == True: + if valid is True: good_outputs.add(o) else: - print("Failed to validate prompt for output {} {}".format(o, reason)) - print("output will be ignored") - errors += [(o, reason)] - if node_id is not None: - if node_id not in node_errors: - node_errors[node_id] = {"message": reason, "dependent_outputs": []} - node_errors[node_id]["dependent_outputs"].append(o) + print(f"Failed to validate prompt for output {o}:") + if len(reasons) > 0: + print("* (prompt):") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + errors += [(o, reasons)] + for node_id, result in validated.items(): + valid = result[0] + reasons = result[1] + # If a node upstream has errors, the nodes downstream will also + # be reported as invalid, but there will be no errors attached. + # So don't return those nodes as having errors in the response. 
+ if valid is not True and len(reasons) > 0: + if node_id not in node_errors: + class_type = prompt[node_id]['class_type'] + node_errors[node_id] = { + "errors": reasons, + "dependent_outputs": [], + "class_type": class_type + } + print(f"* {class_type} {node_id}:") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + node_errors[node_id]["dependent_outputs"].append(o) + print("Output will be ignored") if len(good_outputs) == 0: - errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) + errors_list = [] + for o, errors in errors: + for error in errors: + errors_list.append(f"{error['message']}: {error['details']}") + errors_list = "\n".join(errors_list) - return (True, "", list(good_outputs), node_errors) + error = { + "type": "prompt_no_good_outputs", + "message": "Prompt has no properly connected outputs", + "details": errors_list, + "extra_info": {} + } + + return (False, error, list(good_outputs), node_errors) + + return (True, None, list(good_outputs), node_errors) class PromptQueue: From cc4d3435d3590288e21f3adfd42f044a7e45fae4 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:48:55 -0500 Subject: [PATCH 09/82] Highlight failing nodes/inputs in frontend --- web/scripts/app.js | 74 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 67 insertions(+), 7 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 97b7c8d31..21fe94802 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -771,16 +771,25 @@ export class ComfyApp { LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) { const res = origDrawNodeShape.apply(this, arguments); + const nodeErrors = self.lastPromptError?.node_errors[node.id]; + let color = null; + let lineWidth = 1; if (node.id === +self.runningNodeId) { color = "#0f0"; } else if (self.dragOverNode && node.id === self.dragOverNode.id) { color = "dodgerblue"; } + else if (self.lastPromptError != null && nodeErrors?.errors) { + color = "red"; + lineWidth = 2; + } + + self.graphTime = Date.now() if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; - ctx.lineWidth = 1; + ctx.lineWidth = lineWidth; ctx.globalAlpha = 0.8; ctx.beginPath(); if (shape == LiteGraph.BOX_SHAPE) @@ -807,11 +816,28 @@ export class ComfyApp { ctx.stroke(); ctx.strokeStyle = fgcolor; ctx.globalAlpha = 1; + } - if (self.progress) { - ctx.fillStyle = "green"; - ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); - ctx.fillStyle = bgcolor; + if (self.progress && node.id === +self.runningNodeId) { + ctx.fillStyle = "green"; + ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); + ctx.fillStyle = bgcolor; + } + + // Highlight inputs that failed validation + if (nodeErrors) { + ctx.lineWidth = 2; + ctx.strokeStyle = "red"; + for (const error of nodeErrors.errors) { + if (error.extra_info && error.extra_info.input_name) { + const inputIndex = node.findInputSlot(error.extra_info.input_name) + if (inputIndex !== -1) { + let pos = node.getConnectionPos(true, inputIndex); + ctx.beginPath(); + ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false) + ctx.stroke(); + } + } } } @@ -1243,6 +1269,31 @@ export class ComfyApp { return { workflow, output }; } + #formatError(error) { + if (error == 
null) { + return "(unknown error)" + } + else if (typeof error === "string") { + return error; + } + else if (error.stack && error.message) { + return error.toString() + } + else if (error.response) { + let message = error.response.error.message; + if (error.response.error.details) + message += ": " + error.response.error.details; + for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) { + message += "\n" + nodeError.class_type + ":" + for (const errorReason of nodeError.errors) { + message += "\n - " + errorReason.message + ": " + errorReason.details + } + } + return message + } + return "(unknown error)" + } + async queuePrompt(number, batchCount = 1) { this.#queueItems.push({ number, batchCount }); @@ -1250,8 +1301,10 @@ export class ComfyApp { if (this.#processingQueue) { return; } - + this.#processingQueue = true; + this.lastPromptError = null; + try { while (this.#queueItems.length) { ({ number, batchCount } = this.#queueItems.pop()); @@ -1262,7 +1315,12 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response.error || error.toString()); + const formattedError = this.#formatError(error) + this.ui.dialog.show(formattedError); + if (error.response) { + this.lastPromptError = error.response; + this.canvas.draw(true, true); + } break; } @@ -1360,6 +1418,8 @@ export class ComfyApp { */ clean() { this.nodeOutputs = {}; + this.lastPromptError = null; + this.graphTime = null } } From c33b7c5549b7b277011e2c3f50215ba466afb205 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:54:13 -0500 Subject: [PATCH 10/82] Improve invalid prompt error message --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 691beb102..66753ff90 100644 --- a/execution.py +++ b/execution.py @@ -554,8 +554,8 @@ def validate_prompt(prompt): errors_list = "\n".join(errors_list) error = { - "type": "prompt_no_good_outputs", - "message": "Prompt has no properly connected outputs", + "type": "prompt_outputs_failed_validation", + "message": "Prompt outputs failed validation", "details": errors_list, "extra_info": {} } From 0d834e3a2ba6272b8cee6503f574c0f06002ddc3 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:59:30 -0500 Subject: [PATCH 11/82] Add missing input name/config --- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index 66753ff90..632aaa843 100644 --- a/execution.py +++ b/execution.py @@ -365,6 +365,8 @@ def validate_inputs(prompt, item, validated): "message": "Exception when validating node", "details": str(ex), "extra_info": { + "input_name": x, + "input_config": info, "error_type": error_type, "traceback": traceback.format_tb(tb) } From ffec815257ddf2371b880eafd575838210fcea07 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 12:48:06 -0500 Subject: [PATCH 12/82] Send back more information about exceptions that happen during execution --- execution.py | 173 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 123 insertions(+), 50 deletions(-) diff --git a/execution.py b/execution.py index 632aaa843..5ed9ff348 100644 --- a/execution.py +++ b/execution.py @@ -102,13 +102,19 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui +def 
format_value(x): + if isinstance(x, (int, float, bool, str)): + return x + else: + return str(x) + def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return + return (True, None, None) for x in inputs: input_data = inputs[x] @@ -117,22 +123,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + if result[0] is not True: + # Another node failed further upstream + return result - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) - obj = class_def() - - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data - if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + input_data_all = None + try: + input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.last_node_id = unique_id + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + obj = class_def() + + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui + if server.client_id is not None: + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + except comfy.model_management.InterruptProcessingException as iex: + print("Processing interrupted") + + # skip formatting inputs/outputs + error_details = { + "node_id": unique_id, + } + + return (False, error_details, iex) + except Exception as ex: + typ, _, tb = sys.exc_info() + exception_type = full_type_name(typ) + input_data_formatted = {} + if input_data_all is not None: + input_data_formatted = {} + for name, inputs in input_data_all.items(): + input_data_formatted[name] = [format_value(x) for x in inputs] + + output_data_formatted = {} + for node_id, node_outputs in outputs.items(): + output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] + + print("!!! 
Exception during processing !!!") + print(traceback.format_exc()) + + error_details = { + "node_id": unique_id, + "message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "current_inputs": input_data_formatted, + "current_outputs": output_data_formatted + } + return (False, error_details, ex) + executed.add(unique_id) + return (True, None, None) + def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -210,6 +258,44 @@ class PromptExecutor: self.old_prompt = {} self.server = server + def handle_execution_error(self, prompt_id, current_outputs, executed, error, ex): + # First, send back the status to the frontend depending + # on the exception type + if isinstance(ex, comfy.model_management.InterruptProcessingException): + mes = { + "prompt_id": prompt_id, + "executed": list(executed), + + "node_id": error["node_id"], + } + self.server.send_sync("execution_interrupted", mes, self.server.client_id) + else: + if self.server.client_id is not None: + mes = { + "prompt_id": prompt_id, + "executed": list(executed), + + "message": error["message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "node_id": error["node_id"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.server.send_sync("execution_error", mes, self.server.client_id) + + # Next, remove the subsequent outputs since they will not be executed + to_delete = [] + for o in self.outputs: + if (o not in current_outputs) and (o not in executed): + to_delete += [o] + if o in self.old_prompt: + d = self.old_prompt.pop(o) + del d + for o in to_delete: + d = self.outputs.pop(o) + del d + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) @@ -244,42 +330,29 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() - try: - to_execute = [] - for x in list(execute_outputs): - to_execute += [(0, x)] + output_node_id = None + to_execute = [] - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - x = to_execute.pop(0)[-1] + for node_id in list(execute_outputs): + to_execute += [(0, node_id)] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) - except Exception as e: - if isinstance(e, comfy.model_management.InterruptProcessingException): - print("Processing interrupted") - else: - message = str(traceback.format_exc()) - print(message) - if self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) + while len(to_execute) > 0: + #always execute the output that depends on the least amount of unexecuted nodes first + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + output_node_id = to_execute.pop(0)[-1] - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - finally: - for x in 
executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) - self.server.last_node_id = None - if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + # This call shouldn't raise anything if there's an error deep in + # the actual SD code, instead it will report the node where the + # error was raised + success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) + if success is not True: + self.handle_execution_error(prompt_id, current_outputs, executed, error, ex) + + for x in executed: + self.old_prompt[x] = copy.deepcopy(prompt[x]) + self.server.last_node_id = None + if self.server.client_id is not None: + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() @@ -359,7 +432,7 @@ def validate_inputs(prompt, item, validated): except Exception as ex: typ, _, tb = sys.exc_info() valid = False - error_type = full_type_name(typ) + exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", @@ -367,7 +440,7 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, - "error_type": error_type, + "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }] @@ -507,13 +580,13 @@ def validate_prompt(prompt): except Exception as ex: typ, _, tb = sys.exc_info() valid = False - error_type = full_type_name(typ) + exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", "details": str(ex), "extra_info": { - "error_type": error_type, + "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }] From 6b2a8a3845972bcff02184aaa8ded6eace8300ad Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:03:41 -0500 Subject: [PATCH 13/82] Show message in the frontend if prompt execution raises an exception --- execution.py | 14 +++++++++----- web/scripts/api.js | 6 ++++++ web/scripts/app.js | 35 ++++++++++++++++++++++++++++++----- 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/execution.py b/execution.py index 5ed9ff348..f79c3d351 100644 --- a/execution.py +++ b/execution.py @@ -258,27 +258,31 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def handle_execution_error(self, prompt_id, current_outputs, executed, error, ex): + def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): + node_id = error["node_id"] + class_type = prompt[node_id]["class_type"] + # First, send back the status to the frontend depending # on the exception type if isinstance(ex, comfy.model_management.InterruptProcessingException): mes = { "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, "executed": list(executed), - - "node_id": error["node_id"], } self.server.send_sync("execution_interrupted", mes, self.server.client_id) else: if self.server.client_id is not None: mes = { "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, "executed": list(executed), "message": error["message"], "exception_type": error["exception_type"], "traceback": error["traceback"], - "node_id": error["node_id"], "current_inputs": error["current_inputs"], 
"current_outputs": error["current_outputs"], } @@ -346,7 +350,7 @@ class PromptExecutor: # error was raised success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) if success is not True: - self.handle_execution_error(prompt_id, current_outputs, executed, error, ex) + self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) diff --git a/web/scripts/api.js b/web/scripts/api.js index 4f061c358..378165b3a 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -88,6 +88,12 @@ class ComfyApi extends EventTarget { case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); break; + case "execution_start": + this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data })); + break; + case "execution_error": + this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data })); + break; default: if (this.#registered.has(msg.type)) { this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); diff --git a/web/scripts/app.js b/web/scripts/app.js index 21fe94802..e8ab32cf9 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -784,8 +784,10 @@ export class ComfyApp { color = "red"; lineWidth = 2; } - - self.graphTime = Date.now() + else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) { + color = "#f0f"; + lineWidth = 2; + } if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; @@ -895,6 +897,17 @@ export class ComfyApp { } }); + api.addEventListener("execution_start", ({ detail }) => { + this.lastExecutionError = null + }); + + api.addEventListener("execution_error", ({ detail }) => { + this.lastExecutionError = detail; + const formattedError = this.#formatExecutionError(detail); + this.ui.dialog.show(formattedError); + this.canvas.draw(true, true); + }); + api.init(); } @@ -1269,7 +1282,7 @@ export class ComfyApp { return { workflow, output }; } - #formatError(error) { + #formatPromptError(error) { if (error == null) { return "(unknown error)" } @@ -1294,6 +1307,18 @@ export class ComfyApp { return "(unknown error)" } + #formatExecutionError(error) { + if (error == null) { + return "(unknown error)" + } + + const traceback = error.traceback.join("") + const nodeId = error.node_id + const nodeType = error.node_type + + return `Error occurred when executing ${nodeType}:\n\n${error.message}\n\n${traceback}` + } + async queuePrompt(number, batchCount = 1) { this.#queueItems.push({ number, batchCount }); @@ -1315,7 +1340,7 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - const formattedError = this.#formatError(error) + const formattedError = this.#formatPromptError(error) this.ui.dialog.show(formattedError); if (error.response) { this.lastPromptError = error.response; @@ -1419,7 +1444,7 @@ export class ComfyApp { clean() { this.nodeOutputs = {}; this.lastPromptError = null; - this.graphTime = null + this.lastExecutionError = null; } } From e2d080b6941783e50155f694c11ab0da1b1ae240 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:07:51 -0500 Subject: [PATCH 14/82] Return null for value format --- execution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index f79c3d351..9cebce928 100644 --- a/execution.py +++ b/execution.py @@ -103,7 +103,9 @@ 
def get_output_data(obj, input_data_all): return output, ui def format_value(x): - if isinstance(x, (int, float, bool, str)): + if x is None: + return None + elif isinstance(x, (int, float, bool, str)): return x else: return str(x) From a9e7e237248296c8fe0d79991e0f8c2c0f2cf530 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:11:34 -0500 Subject: [PATCH 15/82] Fix --- execution.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/execution.py b/execution.py index 9cebce928..ffea00a8c 100644 --- a/execution.py +++ b/execution.py @@ -499,9 +499,7 @@ def validate_inputs(prompt, item, validated): if r is not True: details = f"{x}" if r is not False: - details += f": {str(r)}" - else: - details += "." + details += f" - {str(r)}" error = { "type": "custom_validation_failed", From 62bdd9d26aba086ffbeedd118140e2806e6f4345 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 26 May 2023 16:35:54 -0500 Subject: [PATCH 16/82] Catch typecast errors --- execution.py | 43 ++++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/execution.py b/execution.py index ffea00a8c..6af58a673 100644 --- a/execution.py +++ b/execution.py @@ -424,7 +424,8 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, - "received_type": received_type + "received_type": received_type, + "linked_node": val } } errors.append(error) @@ -440,28 +441,44 @@ def validate_inputs(prompt, item, validated): valid = False exception_type = full_type_name(typ) reasons = [{ - "type": "exception_during_validation", - "message": "Exception when validating node", + "type": "exception_during_inner_validation", + "message": "Exception when validating inner node", "details": str(ex), "extra_info": { "input_name": x, "input_config": info, "exception_type": exception_type, - "traceback": traceback.format_tb(tb) + "traceback": traceback.format_tb(tb), + "linked_node": val } }] validated[o_id] = (False, reasons, o_id) continue else: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val + try: + if type_input == "INT": + val = int(val) + inputs[x] = val + if type_input == "FLOAT": + val = float(val) + inputs[x] = val + if type_input == "STRING": + val = str(val) + inputs[x] = val + except Exception as ex: + error = { + "type": "invalid_input_type", + "message": f"Failed to convert an input value to a {type_input} value", + "details": f"{x}, {val}, {ex}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + "exception_message": str(ex) + } + } + errors.append(error) + continue if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: From 52c9590b7b65dba86e8622f6ad38974bc4045f31 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 01:51:39 -0500 Subject: [PATCH 17/82] Exception message --- execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/execution.py b/execution.py index 6af58a673..52c264b0f 100644 --- a/execution.py +++ b/execution.py @@ -447,6 +447,7 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, + "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "linked_node": val From 
03f2d0a764726641e848ba4e069c8809a502afdf Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 02:02:11 -0500 Subject: [PATCH 18/82] Rename exception message field --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 52c264b0f..1a9a1ff73 100644 --- a/execution.py +++ b/execution.py @@ -171,7 +171,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute error_details = { "node_id": unique_id, - "message": str(ex), + "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "current_inputs": input_data_formatted, @@ -282,7 +282,7 @@ class PromptExecutor: "node_type": class_type, "executed": list(executed), - "message": error["message"], + "exception_message": error["exception_message"], "exception_type": error["exception_type"], "traceback": error["traceback"], "current_inputs": error["current_inputs"], From 00646b0813e4f395725f3013f18b13a46f4d619d Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 21:48:49 -0500 Subject: [PATCH 19/82] Bitwise operations for masks --- comfy_extras/nodes_mask.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 9916f3b21..9134c24da 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -167,7 +167,7 @@ class MaskComposite: "source": ("MASK",), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract"],), + "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), } } @@ -193,6 +193,12 @@ class MaskComposite: output[top:bottom, left:right] = destination_portion + source_portion elif operation == "subtract": output[top:bottom, left:right] = destination_portion - source_portion + elif operation == "and": + output[top:bottom, left:right] = torch.bitwise_and(destination_portion.bool(), source_portion.bool()).float() + elif operation == "or": + output[top:bottom, left:right] = torch.bitwise_or(destination_portion.bool(), source_portion.bool()).float() + elif operation == "xor": + output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.bool(), source_portion.bool()).float() output = torch.clamp(output, 0.0, 1.0) From ad81fd682a5e5e7c1f258d7c11a000c0dfd07be3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 00:32:26 -0400 Subject: [PATCH 20/82] Fix issue with cancelling prompt. --- execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/execution.py b/execution.py index 1a9a1ff73..218a84c36 100644 --- a/execution.py +++ b/execution.py @@ -353,6 +353,7 @@ class PromptExecutor: success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) if success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + break for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) From f3ac938b4a5c031adb9ee2951f26360d6a2b36de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 00:42:53 -0400 Subject: [PATCH 21/82] Round the mask values for bitwise operations. 
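Rounding matters here because bool() treats any non-zero float as True, so a soft mask value such as 0.4 would otherwise count as True; rounding first makes 0.5 the effective threshold. A small illustrative check (not repository code):

import torch

a = torch.tensor([0.4, 0.6, 0.0])
b = torch.tensor([0.6, 0.6, 1.0])
print(torch.bitwise_and(a.bool(), b.bool()).float())                  # tensor([1., 1., 0.])
print(torch.bitwise_and(a.round().bool(), b.round().bool()).float())  # tensor([0., 1., 0.])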
--- comfy_extras/nodes_mask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 9134c24da..15377af14 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -194,11 +194,11 @@ class MaskComposite: elif operation == "subtract": output[top:bottom, left:right] = destination_portion - source_portion elif operation == "and": - output[top:bottom, left:right] = torch.bitwise_and(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float() elif operation == "or": - output[top:bottom, left:right] = torch.bitwise_or(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float() elif operation == "xor": - output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float() output = torch.clamp(output, 0.0, 1.0) From 0fc483dcfdef457b50d3a67e66b4f463e6ef9d62 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 01:52:09 -0400 Subject: [PATCH 22/82] Refactor diffusers model convert code to be able to reuse it. --- comfy/diffusers_convert.py | 107 ----------------------------------- comfy/diffusers_load.py | 111 +++++++++++++++++++++++++++++++++++++ nodes.py | 4 +- 3 files changed, 113 insertions(+), 109 deletions(-) create mode 100644 comfy/diffusers_load.py diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index ceca80305..1eab54d4b 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -1,14 +1,5 @@ -import json -import os -import yaml - -import folder_paths -from comfy.ldm.util import instantiate_from_config -from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE -import os.path as osp import re import torch -from safetensors.torch import load_file, save_file # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict): return text_enc_dict -def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): - diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) - diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) - - # magic - v2 = diffusers_unet_conf["sample_size"] == 96 - if 'prediction_type' in diffusers_scheduler_conf: - v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' - - if v2: - if v_pred: - config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') - - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) - - model_config_params = config['model']['params'] - clip_config = model_config_params['cond_stage_config'] - scale_factor = model_config_params['scale_factor'] - vae_config = model_config_params['first_stage_config'] - vae_config['scale_factor'] = scale_factor - model_config_params["unet_config"]["params"]["use_fp16"] = fp16 - - unet_path = 
osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") - text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") - - # Load models from safetensors if it exists, if it doesn't pytorch - if osp.exists(unet_path): - unet_state_dict = load_file(unet_path, device="cpu") - else: - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") - unet_state_dict = torch.load(unet_path, map_location="cpu") - - if osp.exists(vae_path): - vae_state_dict = load_file(vae_path, device="cpu") - else: - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") - vae_state_dict = torch.load(vae_path, map_location="cpu") - - if osp.exists(text_enc_path): - text_enc_dict = load_file(text_enc_path, device="cpu") - else: - text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") - text_enc_dict = torch.load(text_enc_path, map_location="cpu") - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict(unet_state_dict) - unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} - - # Convert the VAE model - vae_state_dict = convert_vae_state_dict(vae_state_dict) - vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} - - # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper - is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict - - if is_v20_model: - # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm - text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} - text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) - text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} - else: - text_enc_dict = convert_text_enc_state_dict(text_enc_dict) - text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} - - # Put together new checkpoint - sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} - - clip = None - vae = None - - class WeightsLoader(torch.nn.Module): - pass - - w = WeightsLoader() - load_state_dict_to = [] - if output_vae: - vae = VAE(scale_factor=scale_factor, config=vae_config) - w.first_stage_model = vae.first_stage_model - load_state_dict_to = [w] - - if output_clip: - clip = CLIP(config=clip_config, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - load_state_dict_to = [w] - - model = instantiate_from_config(config["model"]) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) - - if fp16: - model = model.half() - - return ModelPatcher(model), clip, vae diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py new file mode 100644 index 000000000..43877fb83 --- /dev/null +++ b/comfy/diffusers_load.py @@ -0,0 +1,111 @@ +import json +import os +import yaml + +import folder_paths +from comfy.ldm.util import instantiate_from_config +from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE +import os.path as osp +import re +import torch +from safetensors.torch import load_file, save_file +import diffusers_convert + +def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): + diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) + diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) + + # magic + v2 = diffusers_unet_conf["sample_size"] == 96 + if 'prediction_type' in diffusers_scheduler_conf: + v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + + if v2: + if v_pred: + config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') + + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) + + model_config_params = config['model']['params'] + clip_config = model_config_params['cond_stage_config'] + scale_factor = model_config_params['scale_factor'] + vae_config = model_config_params['first_stage_config'] + vae_config['scale_factor'] = scale_factor + model_config_params["unet_config"]["params"]["use_fp16"] = fp16 + + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict) + 
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict) + text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} + + # Put together new checkpoint + sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + + clip = None + vae = None + + class WeightsLoader(torch.nn.Module): + pass + + w = WeightsLoader() + load_state_dict_to = [] + if output_vae: + vae = VAE(scale_factor=scale_factor, config=vae_config) + w.first_stage_model = vae.first_stage_model + load_state_dict_to = [w] + + if output_clip: + clip = CLIP(config=clip_config, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + load_state_dict_to = [w] + + model = instantiate_from_config(config["model"]) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) + + if fp16: + model = model.half() + + return ModelPatcher(model), clip, vae diff --git a/nodes.py b/nodes.py index 68010f040..90444a92c 100644 --- a/nodes.py +++ b/nodes.py @@ -17,7 +17,7 @@ import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) -import comfy.diffusers_convert +import comfy.diffusers_load import comfy.samplers import comfy.sample import comfy.sd @@ -377,7 +377,7 @@ class DiffusersLoader: model_path = path break - return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: From a532888846809de7b8890e8beb10ea87edf39d7e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 02:02:09 -0400 Subject: [PATCH 23/82] Support VAEs in diffusers format. --- comfy/sd.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index c6be900ad..4df149fe1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ from .t2i_adapter import adapter from . import utils from . import clip_vision from . import gligen +from . 
import diffusers_convert def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): m, u = model.load_state_dict(sd, strict=False) @@ -504,10 +505,16 @@ class VAE: if config is None: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") else: - self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() + if ckpt_path is not None: + sd = utils.load_torch_file(ckpt_path) + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) + self.first_stage_model.load_state_dict(sd, strict=False) + self.scale_factor = scale_factor if device is None: device = model_management.get_torch_device() From 23ffafeb5d4a25bb5e41c34c9f04a0733643892c Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Sun, 28 May 2023 23:31:40 +0900 Subject: [PATCH 24/82] typo fix: field name in error message --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index e8ab32cf9..26670239b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1316,7 +1316,7 @@ export class ComfyApp { const nodeId = error.node_id const nodeType = error.node_type - return `Error occurred when executing ${nodeType}:\n\n${error.message}\n\n${traceback}` + return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}` } async queuePrompt(number, batchCount = 1) { From b9818eb910b6ce683c38602c9b8fbd3979d97aaf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 02:48:50 -0400 Subject: [PATCH 25/82] Add route to get safetensors metadata: /view_metadata/loras?filename=lora.safetensors --- comfy/utils.py | 9 +++++++++ folder_paths.py | 2 ++ server.py | 25 ++++++++++++++++++++++++- 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index 5ed9aaa02..4e84e870b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,6 @@ import torch import math +import struct def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -50,6 +51,14 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +def safetensors_header(safetensors_path, max_size=100*1024*1024): + with open(safetensors_path, "rb") as f: + header = f.read(8) + length_of_header = struct.unpack(' max_size: + return None + return f.read(length_of_header) + def bislerp(samples, width, height): def slerp(b1, b2, r): '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. 
NxC''' diff --git a/folder_paths.py b/folder_paths.py index 20b461c94..19245a617 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -126,11 +126,13 @@ def filter_files_extensions(files, extensions): def get_full_path(folder_name, filename): global folder_names_and_paths folders = folder_names_and_paths[folder_name] + filename = os.path.relpath(os.path.join("/", filename), "/") for x in folders[0]: full_path = os.path.join(x, filename) if os.path.isfile(full_path): return full_path + return None def get_filename_list(folder_name): global folder_names_and_paths diff --git a/server.py b/server.py index c0f79cbd5..72c565a63 100644 --- a/server.py +++ b/server.py @@ -22,7 +22,7 @@ except ImportError: import mimetypes from comfy.cli_args import args - +import comfy.utils @web.middleware async def cache_control(request: web.Request, handler): @@ -257,6 +257,29 @@ class PromptServer(): return web.Response(status=404) + @routes.get("/view_metadata/{folder_name}") + async def view_metadata(request): + folder_name = request.match_info.get("folder_name", None) + if folder_name is None: + return web.Response(status=404) + if not "filename" in request.rel_url.query: + return web.Response(status=404) + + filename = request.rel_url.query["filename"] + if not filename.endswith(".safetensors"): + return web.Response(status=404) + + safetensors_path = folder_paths.get_full_path(folder_name, filename) + if safetensors_path is None: + return web.Response(status=404) + out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024) + if out is None: + return web.Response(status=404) + dt = json.loads(out) + if not "__metadata__" in dt: + return web.Response(status=404) + return web.json_response(dt["__metadata__"]) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) From 560e9f7a43242c51da2589a33f659ecd41914b20 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 11:29:00 -0400 Subject: [PATCH 26/82] Disable repo owner validation in update.py --- .ci/update_windows/update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index c09f29a80..ef9374c44 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'): else: raise AssertionError('Unknown merge analysis result') - +pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) repo = pygit2.Repository(str(sys.argv[1])) ident = pygit2.Signature('comfyui', 'comfy@ui') try: From 08abd838b82ea8d08a7e6f1484140d1694180381 Mon Sep 17 00:00:00 2001 From: "Lt.Dr.Data" Date: Tue, 30 May 2023 15:26:45 +0900 Subject: [PATCH 27/82] HOTFIX: Patched the conflict issue between the Combo Refresh feature and PrimitiveNodes. --- web/scripts/app.js | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 26670239b..64adc3e6a 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1424,6 +1424,11 @@ export class ComfyApp { const def = defs[node.type]; + // HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes, + // and additional work is needed to consider the primitive logic in the refresh logic. 
+ if(!def) + continue; + for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { From eb448dd8e18125b569bea9002f909769678a6c43 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 May 2023 12:36:41 -0400 Subject: [PATCH 28/82] Auto load model in lowvram if not enough memory. --- comfy/model_management.py | 46 ++++++++++++++++++++++++--------------- comfy/sd.py | 18 +++++++++++++-- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c15323219..10a706793 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -15,9 +15,8 @@ vram_state = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM total_vram = 0 -total_vram_available_mb = -1 -accelerate_enabled = False +lowvram_available = True xpu_available = False directml_enabled = False @@ -31,11 +30,12 @@ if args.directml is not None: directml_device = torch_directml.device(device_index) print("Using directml with device:", torch_directml.device_name(device_index)) # torch_directml.disable_tiled_resources(True) + lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. try: import torch if directml_enabled: - total_vram = 4097 #TODO + pass #TODO else: try: import intel_extension_for_pytorch as ipex @@ -46,7 +46,7 @@ try: total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) if not args.normalvram and not args.cpu: - if total_vram <= 4096: + if lowvram_available and total_vram <= 4096: print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. 
If you don't want this use: --normalvram") set_vram_to = VRAMState.LOW_VRAM elif total_vram > total_ram * 1.1 and total_vram > 14336: @@ -92,6 +92,7 @@ if ENABLE_PYTORCH_ATTENTION: if args.lowvram: set_vram_to = VRAMState.LOW_VRAM + lowvram_available = True elif args.novram: set_vram_to = VRAMState.NO_VRAM elif args.highvram: @@ -103,18 +104,18 @@ if args.force_fp32: FORCE_FP32 = True -if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + +if lowvram_available: try: import accelerate - accelerate_enabled = True - vram_state = set_vram_to + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to except Exception as e: import traceback print(traceback.format_exc()) - print("ERROR: COULD NOT ENABLE LOW VRAM MODE.") + print("ERROR: LOW VRAM MODE NEEDS accelerate.") + lowvram_available = False - total_vram_available_mb = (total_vram - 1024) // 2 - total_vram_available_mb = int(max(256, total_vram_available_mb)) try: if torch.backends.mps.is_available(): @@ -199,22 +200,33 @@ def load_model_gpu(model): model.unpatch_model() raise e - model.model_patches_to(get_torch_device()) + torch_dev = get_torch_device() + model.model_patches_to(torch_dev) + + vram_set_state = vram_state + if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): + model_size = model.model_size() + current_free_mem = get_free_memory(torch_dev) + lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.2 )) + if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary + vram_set_state = VRAMState.LOW_VRAM + current_loaded_model = model - if vram_state == VRAMState.CPU: + + if vram_set_state == VRAMState.CPU: pass - elif vram_state == VRAMState.MPS: + elif vram_set_state == VRAMState.MPS: mps_device = torch.device("mps") real_model.to(mps_device) pass - elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: + elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM: model_accelerated = False real_model.to(get_torch_device()) else: - if vram_state == VRAMState.NO_VRAM: + if vram_set_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - elif vram_state == VRAMState.LOW_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) + elif vram_set_state == VRAMState.LOW_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device()) model_accelerated = True diff --git a/comfy/sd.py b/comfy/sd.py index 4df149fe1..ce17994f7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -286,15 +286,29 @@ def model_lora_keys(model, key_map={}): return key_map + class ModelPatcher: - def __init__(self, model): + def __init__(self, model, size=0): + self.size = size self.model = model self.patches = [] self.backup = {} self.model_options = {"transformer_options":{}} + self.model_size() + + def model_size(self): + if self.size > 0: + return self.size + model_sd = self.model.state_dict() + size = 0 + for k in model_sd: + t = model_sd[k] + size += t.nelement() * t.element_size() + self.size = size + return size def clone(self): - n = ModelPatcher(self.model) + n = ModelPatcher(self.model, self.size) 
n.patches = self.patches[:] n.model_options = copy.deepcopy(self.model_options) return n From 2260802d90c41f1475a7bf2960aa018dc25f1001 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 May 2023 16:44:09 -0400 Subject: [PATCH 29/82] Check if folder_name is valid instead of just throwing exception. --- folder_paths.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/folder_paths.py b/folder_paths.py index 19245a617..fc37e52c7 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -125,6 +125,8 @@ def filter_files_extensions(files, extensions): def get_full_path(folder_name, filename): global folder_names_and_paths + if folder_name not in folder_names_and_paths: + return None folders = folder_names_and_paths[folder_name] filename = os.path.relpath(os.path.join("/", filename), "/") for x in folders[0]: From 04f4fba013da1f556fc310235d5a30c2bfe682e8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 16:01:49 -0500 Subject: [PATCH 30/82] Fix litegraph dialog CSS --- web/index.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/index.html b/web/index.html index bb79433ce..da0adb6c2 100644 --- a/web/index.html +++ b/web/index.html @@ -14,5 +14,5 @@ window.graph = app.graph; - + From 468c27afea29928d7d9fcd208e1137a36118ad13 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 16:06:17 -0500 Subject: [PATCH 31/82] Fix litegraph dialog z-index/font --- web/style.css | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/web/style.css b/web/style.css index 87f096e14..db82887c3 100644 --- a/web/style.css +++ b/web/style.css @@ -289,6 +289,11 @@ button.comfy-queue-btn { /* Context menu */ +.litegraph .dialog { + z-index: 1; + font-family: Arial; +} + .litegraph .litemenu-entry.has_submenu { position: relative; padding-right: 20px; From 8ef197f02852b65509d6ebe06df8794b96a07f2f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 11:26:57 -0400 Subject: [PATCH 32/82] Keep list of filenames and only refresh it when something changes. --- folder_paths.py | 47 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index fc37e52c7..f3d1b8773 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -31,6 +31,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") +filename_list_cache = {} + if not os.path.exists(input_directory): os.makedirs(input_directory) @@ -111,12 +113,18 @@ def get_folder_paths(folder_name): return folder_names_and_paths[folder_name][0][:] def recursive_search(directory): + if not os.path.isdir(directory): + return [], {} result = [] + dirs = {directory: os.path.getmtime(directory)} for root, subdir, file in os.walk(directory, followlinks=True): for filepath in file: #we os.path,join directory with a blank string to generate a path separator at the end. 
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) - return result + for d in subdir: + path = os.path.join(root, d) + dirs[path] = os.path.getmtime(path) + return result, dirs def filter_files_extensions(files, extensions): return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) @@ -136,13 +144,44 @@ def get_full_path(folder_name, filename): return None -def get_filename_list(folder_name): +def get_filename_list_(folder_name): global folder_names_and_paths output_list = set() folders = folder_names_and_paths[folder_name] + output_folders = {} for x in folders[0]: - output_list.update(filter_files_extensions(recursive_search(x), folders[1])) - return sorted(list(output_list)) + files, folders_all = recursive_search(x) + output_list.update(filter_files_extensions(files, folders[1])) + output_folders = {**output_folders, **folders_all} + + return (sorted(list(output_list)), output_folders) + +def cached_filename_list_(folder_name): + global filename_list_cache + global folder_names_and_paths + if folder_name not in filename_list_cache: + return None + out = filename_list_cache[folder_name] + for x in out[1]: + time_modified = out[1][x] + folder = x + if os.path.getmtime(folder) != time_modified: + return None + + folders = folder_names_and_paths[folder_name] + for x in folders[0]: + if x not in out[1]: + return None + + return out + +def get_filename_list(folder_name): + out = cached_filename_list_(folder_name) + if out is None: + out = get_filename_list_(folder_name) + global filename_list_cache + filename_list_cache[folder_name] = out + return out[0] def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def map_filename(filename): From 1f34bf08f06550fb2f041188b5a01d395240be17 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Wed, 31 May 2023 22:01:25 +0900 Subject: [PATCH 33/82] To support dynamic custom loading, separate the node registration process based on the defs in the registerNodes function. 
--- web/scripts/app.js | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 64adc3e6a..9ecad8489 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1010,6 +1010,11 @@ export class ComfyApp { const app = this; // Load node definitions from the backend const defs = await api.getNodeDefs(); + this.registerNodesFromDefs(defs); + await this.#invokeExtensionsAsync("registerCustomNodes"); + } + + async registerNodesFromDefs(defs) { await this.#invokeExtensionsAsync("addCustomNodeDefs", defs); // Generate list of known widgets @@ -1082,8 +1087,6 @@ export class ComfyApp { LiteGraph.registerNodeType(nodeId, node); node.category = nodeData.category; } - - await this.#invokeExtensionsAsync("registerCustomNodes"); } /** From 8e8d6070f2e80aff0200bb3ad0f31716a98d5739 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Wed, 31 May 2023 23:26:56 +0900 Subject: [PATCH 34/82] race condition patch --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 9ecad8489..8a9c7ca49 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1010,7 +1010,7 @@ export class ComfyApp { const app = this; // Load node definitions from the backend const defs = await api.getNodeDefs(); - this.registerNodesFromDefs(defs); + await this.registerNodesFromDefs(defs); await this.#invokeExtensionsAsync("registerCustomNodes"); } From 03da8a34265bb333d03a51d7503697b5ede9b335 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 31 May 2023 13:03:24 -0400 Subject: [PATCH 35/82] This is useless for inference. --- comfy/sd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index ce17994f7..fa7bd8d32 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -743,7 +743,7 @@ def load_controlnet(ckpt_path, model=None): use_spatial_transformer=True, transformer_depth=1, context_dim=context_dim, - use_checkpoint=True, + use_checkpoint=False, legacy=False, use_fp16=use_fp16) else: @@ -760,7 +760,7 @@ def load_controlnet(ckpt_path, model=None): use_linear_in_transformer=True, transformer_depth=1, context_dim=context_dim, - use_checkpoint=True, + use_checkpoint=False, legacy=False, use_fp16=use_fp16) if pth: @@ -1045,7 +1045,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o } unet_config = { - "use_checkpoint": True, + "use_checkpoint": False, "image_size": 32, "out_channels": 4, "attention_resolutions": [ From d200fa131420a8871633b7321664db419aab2712 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Wed, 31 May 2023 19:00:01 -0500 Subject: [PATCH 36/82] Prevent callers from mutating folder lists --- folder_paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/folder_paths.py b/folder_paths.py index f3d1b8773..e179a28d4 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -181,7 +181,7 @@ def get_filename_list(folder_name): out = get_filename_list_(folder_name) global filename_list_cache filename_list_cache[folder_name] = out - return out[0] + return list(out[0]) def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def map_filename(filename): From 94680732d32b4b540251c122aee36df8d37266e1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 1 Jun 2023 03:52:51 -0400 Subject: [PATCH 37/82] Empty cache on mps. 
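MPS keeps its own allocator cache, so freeing models is not enough on Apple Silicon; the cache has to be released per backend. A minimal sketch of that dispatch, assuming a PyTorch build that exposes torch.mps (the function name below is a hypothetical stand-in, not the repository API):

    import torch

    def empty_backend_cache():
        # Apple Silicon: release the MPS allocator cache
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            torch.mps.empty_cache()
        # NVIDIA/AMD: hand cached blocks back to the driver
        elif torch.cuda.is_available():
            torch.cuda.empty_cache()
            if torch.version.cuda:  # skip ipc_collect on ROCm builds
                torch.cuda.ipc_collect()
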
--- comfy/model_management.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 10a706793..60bcd786b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -389,7 +389,10 @@ def should_use_fp16(): def soft_empty_cache(): global xpu_available - if xpu_available: + global vram_state + if vram_state == VRAMState.MPS: + torch.mps.empty_cache() + elif xpu_available: torch.xpu.empty_cache() elif torch.cuda.is_available(): if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda From 5c38958e49efd11b5234cb5ff472d752698c5090 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 1 Jun 2023 04:04:35 -0400 Subject: [PATCH 38/82] Tweak lowvram model memory so it's closer to what it was before. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 60bcd786b..e9af7f3a7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -207,7 +207,7 @@ def load_model_gpu(model): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = model.model_size() current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.2 )) + lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary vram_set_state = VRAMState.LOW_VRAM From 1bbd3f7fe16e6637bba232059d004a5fe7804a30 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 1 Jun 2023 22:15:06 -0500 Subject: [PATCH 39/82] Send back prompt number from prompt/ endpoint --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 72c565a63..0b64df147 100644 --- a/server.py +++ b/server.py @@ -361,7 +361,7 @@ class PromptServer(): prompt_id = str(uuid.uuid4()) outputs_to_execute = valid[2] self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) - return web.json_response({"prompt_id": prompt_id}) + return web.json_response({"prompt_id": prompt_id, "number": number}) else: print("invalid prompt:", valid[1]) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) From b5dd15c67ad3f4dbdc23811f40a4c121e318bfe9 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 1 Jun 2023 23:26:23 -0500 Subject: [PATCH 40/82] System stats endpoint --- comfy/model_management.py | 27 +++++++++++++++++++++++++++ server.py | 24 ++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index e9af7f3a7..3b7b1dbf1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -308,6 +308,33 @@ def pytorch_attention_flash_attention(): return True return False +def get_total_memory(dev=None, torch_total_too=False): + global xpu_available + global directml_enabled + if dev is None: + dev = get_torch_device() + + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + mem_total = psutil.virtual_memory().total + else: + if directml_enabled: + mem_total = 1024 * 1024 * 1024 #TODO + mem_total_torch = mem_total + elif xpu_available: + mem_total = 
torch.xpu.get_device_properties(dev).total_memory + mem_total_torch = mem_total + else: + stats = torch.cuda.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total_cuda = torch.cuda.mem_get_info(dev) + mem_total_torch = mem_reserved + mem_total = mem_total_cuda + mem_total_torch + + if torch_total_too: + return (mem_total, mem_total_torch) + else: + return mem_total + def get_free_memory(dev=None, torch_free_too=False): global xpu_available global directml_enabled diff --git a/server.py b/server.py index 0b64df147..acbc88f66 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,7 @@ import execution import uuid import json import glob +import torch from PIL import Image from io import BytesIO @@ -23,6 +24,7 @@ except ImportError: import mimetypes from comfy.cli_args import args import comfy.utils +import comfy.model_management @web.middleware async def cache_control(request: web.Request, handler): @@ -280,6 +282,28 @@ class PromptServer(): return web.Response(status=404) return web.json_response(dt["__metadata__"]) + @routes.get("/system_stats") + async def get_queue(request): + device_index = comfy.model_management.get_torch_device() + device = torch.device(device_index) + device_name = comfy.model_management.get_torch_device_name(device_index) + vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) + vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) + system_stats = { + "devices": [ + { + "name": device_name, + "type": device.type, + "index": device.index, + "vram_total": vram_total, + "vram_free": vram_free, + "torch_vram_total": torch_vram_total, + "torch_vram_free": torch_vram_free, + } + ] + } + return web.json_response(system_stats) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) From 499641ebf1be190e20624ee352e9dc88884e3df1 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 2 Jun 2023 00:14:41 -0500 Subject: [PATCH 41/82] More accurate total --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3b7b1dbf1..0ea0c71e5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -328,7 +328,7 @@ def get_total_memory(dev=None, torch_total_too=False): mem_reserved = stats['reserved_bytes.all.current'] _, mem_total_cuda = torch.cuda.mem_get_info(dev) mem_total_torch = mem_reserved - mem_total = mem_total_cuda + mem_total_torch + mem_total = mem_total_cuda if torch_total_too: return (mem_total, mem_total_torch) From 67892b5ac584ff8def01a5852246c364f8408d95 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Jun 2023 15:05:25 -0400 Subject: [PATCH 42/82] Refactor and improve model_management code related to free memory. --- comfy/model_management.py | 131 +++++++++++++++++++------------------- server.py | 6 +- 2 files changed, 68 insertions(+), 69 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0ea0c71e5..9c3147d76 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1,6 +1,7 @@ import psutil from enum import Enum from comfy.cli_args import args +import torch class VRAMState(Enum): CPU = 0 @@ -33,28 +34,67 @@ if args.directml is not None: lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. 
try: - import torch - if directml_enabled: - pass #TODO - else: - try: - import intel_extension_for_pytorch as ipex - if torch.xpu.is_available(): - xpu_available = True - total_vram = torch.xpu.get_device_properties(torch.xpu.current_device()).total_memory / (1024 * 1024) - except: - total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) - total_ram = psutil.virtual_memory().total / (1024 * 1024) - if not args.normalvram and not args.cpu: - if lowvram_available and total_vram <= 4096: - print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") - set_vram_to = VRAMState.LOW_VRAM - elif total_vram > total_ram * 1.1 and total_vram > 14336: - print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram") - vram_state = VRAMState.HIGH_VRAM + import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): + xpu_available = True except: pass +def get_torch_device(): + global xpu_available + global directml_enabled + if directml_enabled: + global directml_device + return directml_device + if vram_state == VRAMState.MPS: + return torch.device("mps") + if vram_state == VRAMState.CPU: + return torch.device("cpu") + else: + if xpu_available: + return torch.device("xpu") + else: + return torch.device(torch.cuda.current_device()) + +def get_total_memory(dev=None, torch_total_too=False): + global xpu_available + global directml_enabled + if dev is None: + dev = get_torch_device() + + if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): + mem_total = psutil.virtual_memory().total + mem_total_torch = mem_total + else: + if directml_enabled: + mem_total = 1024 * 1024 * 1024 #TODO + mem_total_torch = mem_total + elif xpu_available: + mem_total = torch.xpu.get_device_properties(dev).total_memory + mem_total_torch = mem_total + else: + stats = torch.cuda.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total_cuda = torch.cuda.mem_get_info(dev) + mem_total_torch = mem_reserved + mem_total = mem_total_cuda + + if torch_total_too: + return (mem_total, mem_total_torch) + else: + return mem_total + +total_vram = get_total_memory(get_torch_device()) / (1024 * 1024) +total_ram = psutil.virtual_memory().total / (1024 * 1024) +print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram)) +if not args.normalvram and not args.cpu: + if lowvram_available and total_vram <= 4096: + print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram") + set_vram_to = VRAMState.LOW_VRAM + elif total_vram > total_ram * 1.1 and total_vram > 14336: + print("Enabling highvram mode because your GPU has more vram than your computer has ram. 
If you don't want this use: --normalvram") + vram_state = VRAMState.HIGH_VRAM + try: OOM_EXCEPTION = torch.cuda.OutOfMemoryError except: @@ -128,29 +168,17 @@ if args.cpu: print(f"Set vram state to: {vram_state.name}") -def get_torch_device(): - global xpu_available - global directml_enabled - if directml_enabled: - global directml_device - return directml_device - if vram_state == VRAMState.MPS: - return torch.device("mps") - if vram_state == VRAMState.CPU: - return torch.device("cpu") - else: - if xpu_available: - return torch.device("xpu") - else: - return torch.cuda.current_device() - def get_torch_device_name(device): if hasattr(device, 'type'): - return "{}".format(device.type) - return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) + if device.type == "cuda": + return "{} {}".format(device, torch.cuda.get_device_name(device)) + else: + return "{}".format(device.type) + else: + return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) try: - print("Using device:", get_torch_device_name(get_torch_device())) + print("Device:", get_torch_device_name(get_torch_device())) except: print("Could not pick default device.") @@ -308,33 +336,6 @@ def pytorch_attention_flash_attention(): return True return False -def get_total_memory(dev=None, torch_total_too=False): - global xpu_available - global directml_enabled - if dev is None: - dev = get_torch_device() - - if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): - mem_total = psutil.virtual_memory().total - else: - if directml_enabled: - mem_total = 1024 * 1024 * 1024 #TODO - mem_total_torch = mem_total - elif xpu_available: - mem_total = torch.xpu.get_device_properties(dev).total_memory - mem_total_torch = mem_total - else: - stats = torch.cuda.memory_stats(dev) - mem_reserved = stats['reserved_bytes.all.current'] - _, mem_total_cuda = torch.cuda.mem_get_info(dev) - mem_total_torch = mem_reserved - mem_total = mem_total_cuda - - if torch_total_too: - return (mem_total, mem_total_torch) - else: - return mem_total - def get_free_memory(dev=None, torch_free_too=False): global xpu_available global directml_enabled diff --git a/server.py b/server.py index acbc88f66..5be822a6f 100644 --- a/server.py +++ b/server.py @@ -7,7 +7,6 @@ import execution import uuid import json import glob -import torch from PIL import Image from io import BytesIO @@ -284,9 +283,8 @@ class PromptServer(): @routes.get("/system_stats") async def get_queue(request): - device_index = comfy.model_management.get_torch_device() - device = torch.device(device_index) - device_name = comfy.model_management.get_torch_device_name(device_index) + device = comfy.model_management.get_torch_device() + device_name = comfy.model_management.get_torch_device_name(device) vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True) vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True) system_stats = { From 871a86593ae7eb96518d326c83cfded5d41c6fa6 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Jun 2023 16:34:47 -0400 Subject: [PATCH 43/82] Smarter filename list caching. 
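The goal is to avoid re-walking the model folders on every request: keep the last scan result together with the mtime of every directory that was visited, and only rescan when one of those mtimes changes, with a short grace period so rapid repeated queries reuse the cache outright. A rough standalone sketch of the freshness check (the names and helper below are illustrative, not the exact repository functions):

    import os
    import time

    _cache = {}  # folder_name -> (files, {directory: mtime}, timestamp)

    def cache_is_fresh(folder_name, grace=0.5):
        entry = _cache.get(folder_name)
        if entry is None:
            return False
        _files, dir_mtimes, stamp = entry
        if time.perf_counter() < stamp + grace:
            return True  # scanned moments ago, trust it as-is
        # otherwise compare the recorded mtimes against the filesystem
        return all(os.path.getmtime(d) == m
                   for d, m in dir_mtimes.items() if os.path.isdir(d))
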
--- folder_paths.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/folder_paths.py b/folder_paths.py index e179a28d4..8cee6afde 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,4 +1,5 @@ import os +import time supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) @@ -154,7 +155,7 @@ def get_filename_list_(folder_name): output_list.update(filter_files_extensions(files, folders[1])) output_folders = {**output_folders, **folders_all} - return (sorted(list(output_list)), output_folders) + return (sorted(list(output_list)), output_folders, time.perf_counter()) def cached_filename_list_(folder_name): global filename_list_cache @@ -162,6 +163,8 @@ def cached_filename_list_(folder_name): if folder_name not in filename_list_cache: return None out = filename_list_cache[folder_name] + if time.perf_counter() < (out[2] + 0.5): + return out for x in out[1]: time_modified = out[1][x] folder = x From 66e588d837275b26b428f737692357090ad41426 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 2 Jun 2023 16:48:56 -0400 Subject: [PATCH 44/82] Ignore folder path directories that don't exist. --- folder_paths.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 8cee6afde..a1bf1444d 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -173,8 +173,9 @@ def cached_filename_list_(folder_name): folders = folder_names_and_paths[folder_name] for x in folders[0]: - if x not in out[1]: - return None + if os.path.isdir(x): + if x not in out[1]: + return None return out From 700491d81a9faf5370a0c54d869e902bbfc839ec Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 01:47:21 -0400 Subject: [PATCH 45/82] Implement global average pooling for controlnet. 
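For the "shuffle" style ControlNets every control tensor is collapsed to its per-channel spatial mean and then broadcast back to the original height and width before being scaled and added. A small worked example of that operation on an NCHW tensor:

    import torch

    x = torch.randn(2, 320, 64, 64)                       # example control output
    pooled = torch.mean(x, dim=(2, 3), keepdim=True)      # -> (2, 320, 1, 1)
    pooled = pooled.repeat(1, 1, x.shape[2], x.shape[3])  # -> (2, 320, 64, 64)
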
--- comfy/sd.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index fa7bd8d32..336fee4a6 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -621,7 +621,7 @@ def broadcast_image_to(tensor, target_batch_size, batched_number): return torch.cat([tensor] * batched_number, dim=0) class ControlNet: - def __init__(self, control_model, device=None): + def __init__(self, control_model, global_average_pooling=False, device=None): self.control_model = control_model self.cond_hint_original = None self.cond_hint = None @@ -630,6 +630,7 @@ class ControlNet: device = model_management.get_torch_device() self.device = device self.previous_controlnet = None + self.global_average_pooling = global_average_pooling def get_control(self, x_noisy, t, cond_txt, batched_number): control_prev = None @@ -665,6 +666,9 @@ class ControlNet: key = 'output' index = i x = control[i] + if self.global_average_pooling: + x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3]) + x *= self.strength if x.dtype != output_dtype and not autocast_enabled: x = x.to(output_dtype) @@ -695,7 +699,7 @@ class ControlNet: self.cond_hint = None def copy(self): - c = ControlNet(self.control_model) + c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling) c.cond_hint_original = self.cond_hint_original c.strength = self.strength return c @@ -790,7 +794,11 @@ def load_controlnet(ckpt_path, model=None): if use_fp16: control_model = control_model.half() - control = ControlNet(control_model) + global_average_pooling = False + if ckpt_path.endswith("_shuffle.pth") or ckpt_path.endswith("_shuffle.safetensors") or ckpt_path.endswith("_shuffle_fp16.safetensors"): #TODO: smarter way of enabling global_average_pooling + global_average_pooling = True + + control = ControlNet(control_model, global_average_pooling=global_average_pooling) return control class T2IAdapter: From 0a5fefd6213e3116359e0738533a9e3b733506c5 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 11:05:37 -0400 Subject: [PATCH 46/82] Cleanups and fixes for model_management.py Hopefully fix regression on MPS and CPU. 
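The main cleanup is that "which processor runs the model" (CPU, MPS or a GPU) is now tracked separately from "how aggressively VRAM has to be managed", so device selection only consults the former. A condensed sketch of that selection under the same assumptions as the patch (the function name is a hypothetical stand-in):

    import torch

    def pick_torch_device(force_cpu=False):
        if force_cpu:
            return torch.device("cpu")
        if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            return torch.device("mps")
        if torch.cuda.is_available():
            return torch.device(torch.cuda.current_device())
        return torch.device("cpu")
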
--- comfy/model_management.py | 63 ++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 9c3147d76..a492ca6b9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -4,16 +4,22 @@ from comfy.cli_args import args import torch class VRAMState(Enum): - CPU = 0 + DISABLED = 0 NO_VRAM = 1 LOW_VRAM = 2 NORMAL_VRAM = 3 HIGH_VRAM = 4 - MPS = 5 + SHARED = 5 + +class CPUState(Enum): + GPU = 0 + CPU = 1 + MPS = 2 # Determine VRAM State vram_state = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM +cpu_state = CPUState.GPU total_vram = 0 @@ -40,15 +46,25 @@ try: except: pass +try: + if torch.backends.mps.is_available(): + cpu_state = CPUState.MPS +except: + pass + +if args.cpu: + cpu_state = CPUState.CPU + def get_torch_device(): global xpu_available global directml_enabled + global cpu_state if directml_enabled: global directml_device return directml_device - if vram_state == VRAMState.MPS: + if cpu_state == CPUState.MPS: return torch.device("mps") - if vram_state == VRAMState.CPU: + if cpu_state == CPUState.CPU: return torch.device("cpu") else: if xpu_available: @@ -143,8 +159,6 @@ if args.force_fp32: print("Forcing FP32, if this improves things please report it.") FORCE_FP32 = True - - if lowvram_available: try: import accelerate @@ -157,17 +171,15 @@ if lowvram_available: lowvram_available = False -try: - if torch.backends.mps.is_available(): - vram_state = VRAMState.MPS -except: - pass +if cpu_state != CPUState.GPU: + vram_state = VRAMState.DISABLED -if args.cpu: - vram_state = VRAMState.CPU +if cpu_state == CPUState.MPS: + vram_state = VRAMState.SHARED print(f"Set vram state to: {vram_state.name}") + def get_torch_device_name(device): if hasattr(device, 'type'): if device.type == "cuda": @@ -241,13 +253,9 @@ def load_model_gpu(model): current_loaded_model = model - if vram_set_state == VRAMState.CPU: + if vram_set_state == VRAMState.DISABLED: pass - elif vram_set_state == VRAMState.MPS: - mps_device = torch.device("mps") - real_model.to(mps_device) - pass - elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM: + elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: model_accelerated = False real_model.to(get_torch_device()) else: @@ -263,7 +271,7 @@ def load_model_gpu(model): def load_controlnet_gpu(control_models): global current_gpu_controlnets global vram_state - if vram_state == VRAMState.CPU: + if vram_state == VRAMState.DISABLED: return if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM: @@ -308,7 +316,8 @@ def get_autocast_device(dev): def xformers_enabled(): global xpu_available global directml_enabled - if vram_state == VRAMState.CPU: + global cpu_state + if cpu_state != CPUState.GPU: return False if xpu_available: return False @@ -380,12 +389,12 @@ def maximum_batch_area(): return int(max(area, 0)) def cpu_mode(): - global vram_state - return vram_state == VRAMState.CPU + global cpu_state + return cpu_state == CPUState.CPU def mps_mode(): - global vram_state - return vram_state == VRAMState.MPS + global cpu_state + return cpu_state == CPUState.MPS def should_use_fp16(): global xpu_available @@ -417,8 +426,8 @@ def should_use_fp16(): def soft_empty_cache(): global xpu_available - global vram_state - if vram_state == VRAMState.MPS: + global cpu_state + if cpu_state == CPUState.MPS: torch.mps.empty_cache() elif xpu_available: 
torch.xpu.empty_cache() From 32f282c861eabcee42fdec702b96ebc8924c9834 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 11:19:10 -0400 Subject: [PATCH 47/82] Search box style fix. --- web/style.css | 1 + 1 file changed, 1 insertion(+) diff --git a/web/style.css b/web/style.css index db82887c3..47571a16e 100644 --- a/web/style.css +++ b/web/style.css @@ -336,6 +336,7 @@ button.comfy-queue-btn { z-index: 9999 !important; background-color: var(--comfy-menu-bg) !important; overflow: hidden; + display: block; } .litegraph.litesearchbox input, From c092ffcc18f0a44c062fe914ebda05b29bdcfbc0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 11:46:52 -0400 Subject: [PATCH 48/82] Latest litegraph from upstream. --- web/lib/litegraph.core.js | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 95f4a2735..908ed5f16 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -8099,11 +8099,15 @@ LGraphNode.prototype.executeAction = function(action) bgcolor = bgcolor || LiteGraph.NODE_DEFAULT_COLOR; hovercolor = hovercolor || "#555"; textcolor = textcolor || LiteGraph.NODE_TEXT_COLOR; - var yFix = y + LiteGraph.NODE_TITLE_HEIGHT + 2; // fix the height with the title - var pos = this.mouse; - var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h ); - pos = this.last_click_position; - var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,yFix,w,h ); + var pos = this.ds.convertOffsetToCanvas(this.graph_mouse); + var hover = LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h ); + pos = this.last_click_position ? [this.last_click_position[0], this.last_click_position[1]] : null; + if(pos) { + var rect = this.canvas.getBoundingClientRect(); + pos[0] -= rect.left; + pos[1] -= rect.top; + } + var clicked = pos && LiteGraph.isInsideRectangle( pos[0], pos[1], x,y,w,h ); ctx.fillStyle = hover ? hovercolor : bgcolor; if(clicked) From 0764bb5218ea49fdeeaebbfc10c6f5b87a8bc879 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 3 Jun 2023 11:47:20 -0400 Subject: [PATCH 49/82] Move node properties panel from double click to menu option. --- web/lib/litegraph.core.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/web/lib/litegraph.core.js b/web/lib/litegraph.core.js index 908ed5f16..a60848d77 100644 --- a/web/lib/litegraph.core.js +++ b/web/lib/litegraph.core.js @@ -7294,10 +7294,6 @@ LGraphNode.prototype.executeAction = function(action) if (this.onShowNodePanel) { this.onShowNodePanel(n); } - else - { - this.showShowNodePanel(n); - } if (this.onNodeDblClicked) { this.onNodeDblClicked(n); @@ -13071,6 +13067,10 @@ LGraphNode.prototype.executeAction = function(action) has_submenu: true, callback: LGraphCanvas.onShowMenuNodeProperties }, + { + content: "Properties Panel", + callback: function(item, options, e, menu, node) { LGraphCanvas.active_canvas.showShowNodePanel(node) } + }, null, { content: "Title", From 126b4050dc34daabca51c236bfb5cc31dd48056d Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sun, 4 Jun 2023 01:25:49 +0900 Subject: [PATCH 50/82] Crash fix for intermittent crashes that occur when opening MaskEditor. 
(#732) --- web/extensions/core/maskeditor.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 4b0c12747..6cb3a5385 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -314,11 +314,11 @@ class MaskEditorDialog extends ComfyDialog { imgCtx.drawImage(orig_image, 0, 0, drawWidth, drawHeight); // update mask - backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); maskCanvas.width = drawWidth; maskCanvas.height = drawHeight; maskCanvas.style.top = imgCanvas.offsetTop + "px"; maskCanvas.style.left = imgCanvas.offsetLeft + "px"; + backupCtx.drawImage(maskCanvas, 0, 0, maskCanvas.width, maskCanvas.height, 0, 0, backupCanvas.width, backupCanvas.height); maskCtx.drawImage(backupCanvas, 0, 0, backupCanvas.width, backupCanvas.height, 0, 0, maskCanvas.width, maskCanvas.height); }); From fed0a4dd29852e4808382ef9428a2256214667bf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 4 Jun 2023 17:51:04 -0400 Subject: [PATCH 51/82] Some comments to say what the vram state options mean. --- comfy/model_management.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index a492ca6b9..1a8a1be17 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -4,12 +4,12 @@ from comfy.cli_args import args import torch class VRAMState(Enum): - DISABLED = 0 - NO_VRAM = 1 + DISABLED = 0 #No vram present: no need to move models to vram + NO_VRAM = 1 #Very low vram: enable all the options to save vram LOW_VRAM = 2 NORMAL_VRAM = 3 HIGH_VRAM = 4 - SHARED = 5 + SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both. class CPUState(Enum): GPU = 0 From 9f3a19b72817775b4d567a9a0b7ac870698eb839 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Mon, 5 Jun 2023 14:49:43 +0900 Subject: [PATCH 52/82] improve: lightweight preview to reduce network traffic (#733) * To reduce bandwidth traffic in a remote environment, a lossy compression-based preview mode is provided for displaying simple visualizations in node-based widgets. * Added 'preview=[image format]' option to the '/view' API. * Updated node to use preview for displaying images as widgets. * Excluded preview usage in the open image, save image, mask editor where the original data is required. * Made preview_format parameterizable for extensibility. * default preview format changed: jpeg -> webp * Support advanced preview_format option. - grayscale option for visual debugging - quality option for aggressive reducing L?;format;quality? 
ex) jpeg => rgb, jpeg, quality 90 L;webp;80 => grayscale, webp, quality 80 L;png => grayscale, png, quality 90 webp;50 => rgb, webp, quality 50 * move comment * * add settings for preview_format * default value is ''(= don't reencode) --------- Co-authored-by: Lt.Dr.Data --- server.py | 22 ++++++++++++++++++++++ web/extensions/core/maskeditor.js | 4 +++- web/scripts/app.js | 22 ++++++++++++++++++---- web/scripts/ui.js | 19 +++++++++++++++++++ web/scripts/widgets.js | 2 +- 5 files changed, 63 insertions(+), 6 deletions(-) diff --git a/server.py b/server.py index 5be822a6f..b0dd33828 100644 --- a/server.py +++ b/server.py @@ -217,6 +217,28 @@ class PromptServer(): file = os.path.join(output_dir, filename) if os.path.isfile(file): + if 'preview' in request.rel_url.query: + with Image.open(file) as img: + preview_info = request.rel_url.query['preview'].split(';') + + if preview_info[0] == "L" or preview_info[0] == "l": + img = img.convert("L") + image_format = preview_info[1] + else: + img = img.convert("RGB") # jpeg doesn't support RGBA + image_format = preview_info[0] + + quality = 90 + if preview_info[-1].isdigit(): + quality = int(preview_info[-1]) + + buffer = BytesIO() + img.save(buffer, format=image_format, optimize=True, quality=quality) + buffer.seek(0) + + return web.Response(body=buffer.read(), content_type=f'image/{image_format}', + headers={"Content-Disposition": f"filename=\"{filename}\""}) + if 'channel' not in request.rel_url.query: channel = 'rgba' else: diff --git a/web/extensions/core/maskeditor.js b/web/extensions/core/maskeditor.js index 6cb3a5385..764164d5e 100644 --- a/web/extensions/core/maskeditor.js +++ b/web/extensions/core/maskeditor.js @@ -41,7 +41,7 @@ async function uploadMask(filepath, formData) { }); ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']] = new Image(); - ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" + new URLSearchParams(filepath).toString(); + ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src = "/view?" 
+ new URLSearchParams(filepath).toString() + app.getPreviewFormatParam(); if(ComfyApp.clipspace.images) ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']] = filepath; @@ -335,6 +335,7 @@ class MaskEditorDialog extends ComfyDialog { const alpha_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src) alpha_url.searchParams.delete('channel'); + alpha_url.searchParams.delete('preview'); alpha_url.searchParams.set('channel', 'a'); touched_image.src = alpha_url; @@ -345,6 +346,7 @@ class MaskEditorDialog extends ComfyDialog { const rgb_url = new URL(ComfyApp.clipspace.imgs[ComfyApp.clipspace['selectedIndex']].src); rgb_url.searchParams.delete('channel'); + rgb_url.searchParams.delete('preview'); rgb_url.searchParams.set('channel', 'rgb'); orig_image.src = rgb_url; this.image = orig_image; diff --git a/web/scripts/app.js b/web/scripts/app.js index 8a9c7ca49..95447ffa0 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -51,6 +51,14 @@ export class ComfyApp { this.shiftDown = false; } + getPreviewFormatParam() { + let preview_format = this.ui.settings.getSettingValue("Comfy.PreviewFormat"); + if(preview_format) + return `&preview=${preview_format}`; + else + return ""; + } + static isImageNode(node) { return node.imgs || (node && node.widgets && node.widgets.findIndex(obj => obj.name === 'image') >= 0); } @@ -231,14 +239,20 @@ export class ComfyApp { options.unshift( { content: "Open Image", - callback: () => window.open(img.src, "_blank"), + callback: () => { + let url = new URL(img.src); + url.searchParams.delete('preview'); + window.open(url, "_blank") + }, }, { content: "Save Image", callback: () => { const a = document.createElement("a"); - a.href = img.src; - a.setAttribute("download", new URLSearchParams(new URL(img.src).search).get("filename")); + let url = new URL(img.src); + url.searchParams.delete('preview'); + a.href = url; + a.setAttribute("download", new URLSearchParams(url.search).get("filename")); document.body.append(a); a.click(); requestAnimationFrame(() => a.remove()); @@ -365,7 +379,7 @@ export class ComfyApp { const img = new Image(); img.onload = () => r(img); img.onerror = () => r(null); - img.src = "/view?" + new URLSearchParams(src).toString(); + img.src = "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam(); }); }) ).then((imgs) => { diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 2c9043d00..6b764d43c 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -462,6 +462,25 @@ export class ComfyUI { defaultValue: true, }); + /** + * file format for preview + * + * L?;format;quality + * + * ex) + * L;webp;50 -> grayscale, webp, quality 50 + * jpeg;80 -> rgb, jpeg, quality 80 + * png -> rgb, png, default quality(=90) + * + * @type {string} + */ + const previewImage = this.settings.addSetting({ + id: "Comfy.PreviewFormat", + name: "When displaying a preview in the image widget, convert it to a lightweight image. 
(webp, jpeg, webp;50, ...)", + type: "string", + defaultValue: "", + }); + const fileInput = $el("input", { id: "comfy-file-input", type: "file", diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 82168b08b..d6faaddbf 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -303,7 +303,7 @@ export const ComfyWidgets = { subfolder = name.substring(0, folder_separator); name = name.substring(folder_separator + 1); } - img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}`; + img.src = `/view?filename=${name}&type=input&subfolder=${subfolder}${app.getPreviewFormatParam()}`; node.setSizeForImage?.(); } From 2ec980bb9f3e63fbc605e632d1ebe8837083aaaf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 5 Jun 2023 01:38:32 -0400 Subject: [PATCH 53/82] Limit preview to webp and RGB jpeg. --- server.py | 13 ++++++------- web/scripts/ui.js | 5 ++--- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/server.py b/server.py index b0dd33828..c0b4729de 100644 --- a/server.py +++ b/server.py @@ -221,19 +221,18 @@ class PromptServer(): with Image.open(file) as img: preview_info = request.rel_url.query['preview'].split(';') - if preview_info[0] == "L" or preview_info[0] == "l": - img = img.convert("L") - image_format = preview_info[1] - else: - img = img.convert("RGB") # jpeg doesn't support RGBA - image_format = preview_info[0] + image_format = preview_info[0] + if image_format not in ['webp', 'jpeg']: + image_format = 'webp' quality = 90 if preview_info[-1].isdigit(): quality = int(preview_info[-1]) buffer = BytesIO() - img.save(buffer, format=image_format, optimize=True, quality=quality) + if image_format in ['jpeg']: + img = img.convert("RGB") + img.save(buffer, format=image_format, quality=quality) buffer.seek(0) return web.Response(body=buffer.read(), content_type=f'image/{image_format}', diff --git a/web/scripts/ui.js b/web/scripts/ui.js index 6b764d43c..a26eedec3 100644 --- a/web/scripts/ui.js +++ b/web/scripts/ui.js @@ -465,12 +465,11 @@ export class ComfyUI { /** * file format for preview * - * L?;format;quality + * format;quality * * ex) - * L;webp;50 -> grayscale, webp, quality 50 + * webp;50 -> webp, quality 50 * jpeg;80 -> rgb, jpeg, quality 80 - * png -> rgb, png, default quality(=90) * * @type {string} */ From b4f434ee66b109df67be83265c1b158e3794b241 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 20:43:29 -0500 Subject: [PATCH 54/82] Preview sampled images with TAESD --- comfy/taesd/taesd.py | 65 +++++++++++++++ comfy/utils.py | 4 +- main.py | 6 +- nodes.py | 119 ++++++++++++++++++++++++++-- server.py | 39 +++++++-- web/extensions/core/colorPalette.js | 1 + web/scripts/api.js | 84 +++++++++++++------- web/scripts/app.js | 58 ++++++++++++-- 8 files changed, 324 insertions(+), 52 deletions(-) create mode 100644 comfy/taesd/taesd.py diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py new file mode 100644 index 000000000..e64067454 --- /dev/null +++ b/comfy/taesd/taesd.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +""" +Tiny AutoEncoder for Stable Diffusion +(DNN for encoding / decoding SD's latent space) +""" +import torch +import torch.nn as nn + +def conv(n_in, n_out, **kwargs): + return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + +class Block(nn.Module): + def __init__(self, n_in, n_out): + super().__init__() + self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, 
n_out), nn.ReLU(), conv(n_out, n_out)) + self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.fuse = nn.ReLU() + def forward(self, x): + return self.fuse(self.conv(x) + self.skip(x)) + +def Encoder(): + return nn.Sequential( + conv(3, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 4), + ) + +def Decoder(): + return nn.Sequential( + Clamp(), conv(4, 64), nn.ReLU(), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False), + Block(64, 64), conv(64, 3), + ) + +class TAESD(nn.Module): + latent_magnitude = 3 + latent_shift = 0.5 + + def __init__(self, encoder_path="taesd_encoder.pth", decoder_path="taesd_decoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.encoder = Encoder() + self.decoder = Decoder() + if encoder_path is not None: + self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu")) + if decoder_path is not None: + self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu")) + + @staticmethod + def scale_latents(x): + """raw latents -> [0, 1]""" + return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1) + + @staticmethod + def unscale_latents(x): + """[0, 1] -> raw latents""" + return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) diff --git a/comfy/utils.py b/comfy/utils.py index 4e84e870b..291c62e42 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -197,14 +197,14 @@ class ProgressBar: self.current = 0 self.hook = PROGRESS_BAR_HOOK - def update_absolute(self, value, total=None): + def update_absolute(self, value, total=None, preview=None): if total is not None: self.total = total if value > self.total: value = self.total self.current = value if self.hook is not None: - self.hook(self.current, self.total) + self.hook(self.current, self.total, preview) def update(self, value): self.update_absolute(self.current + value) diff --git a/main.py b/main.py index 50d3b9a62..908ff7af7 100644 --- a/main.py +++ b/main.py @@ -26,6 +26,7 @@ import yaml import execution import folder_paths import server +from server import BinaryEventTypes from nodes import init_custom_nodes @@ -40,8 +41,11 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) def hijack_progress(server): - def hook(value, total): + def hook(value, total, preview_bytes_jpeg): server.send_sync("progress", { "value": value, "max": total}, server.client_id) + if preview_bytes_jpeg is not None: + server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes_jpeg, server.client_id) + pass comfy.utils.set_progress_bar_global_hook(hook) def cleanup_temp(): diff --git a/nodes.py b/nodes.py index 90444a92c..a80f81933 100644 --- a/nodes.py +++ b/nodes.py @@ -7,6 +7,8 @@ import hashlib import traceback import math import time +import struct +from io import BytesIO from PIL import Image, ImageOps from PIL.PngImagePlugin import PngInfo @@ -22,6 +24,7 @@ import comfy.samplers 
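For reference, a minimal sketch (not part of this patch) of how the TAESD decoder defined above turns a 4-channel Stable Diffusion latent back into RGB pixels. The checkpoint filenames are simply the defaults from `TAESD.__init__`, and the random latent is a stand-in for a real sample:

```
import torch
from comfy.taesd.taesd import TAESD

# Assumes the default checkpoint files are available locally.
taesd = TAESD("taesd_encoder.pth", "taesd_decoder.pth")
with torch.no_grad():
    latent = torch.randn(1, 4, 64, 64)           # stand-in for a 512x512 image's latent
    pixels = taesd.decoder(latent).clamp(0, 1)   # [1, 3, 512, 512], values in [0, 1]
```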
import comfy.sample import comfy.sd import comfy.utils +from comfy.taesd.taesd import TAESD import comfy.clip_vision @@ -38,6 +41,7 @@ def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) MAX_RESOLUTION=8192 +MAX_PREVIEW_RESOLUTION = 512 class CLIPTextEncode: @classmethod @@ -171,6 +175,21 @@ class VAEDecodeTiled: def decode(self, vae, samples): return (vae.decode_tiled(samples["samples"]), ) +class TAESDDecode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT", ), "taesd": ("TAESD", )}} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "decode" + + CATEGORY = "latent" + + def decode(self, taesd, samples): + device = comfy.model_management.get_torch_device() + # [B, C, H, W] -> [B, H, W, C] + pixels = taesd.decoder(samples["samples"].to(device)).permute(0, 2, 3, 1).detach().clamp(0, 1) + return (pixels, ) + class VAEEncode: @classmethod def INPUT_TYPES(s): @@ -248,6 +267,21 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) +class TAESDEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { "pixels": ("IMAGE", ), "taesd": ("TAESD", )}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "encode" + + CATEGORY = "latent" + + def encode(self, taesd, pixels): + device = comfy.model_management.get_torch_device() + # [B, H, W, C] -> [B, C, H, W] + samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device) + return ({"samples": samples}, ) + class SaveLatent: def __init__(self): @@ -464,6 +498,26 @@ class VAELoader: vae = comfy.sd.VAE(ckpt_path=vae_path) return (vae,) +class TAESDLoader: + @classmethod + def INPUT_TYPES(s): + model_list = folder_paths.get_filename_list("taesd") + return {"required": { + "encoder_name": (model_list, { "default": "taesd_encoder.pth" }), + "decoder_name": (model_list, { "default": "taesd_decoder.pth" }) + }} + RETURN_TYPES = ("TAESD",) + FUNCTION = "load_taesd" + + CATEGORY = "loaders" + + def load_taesd(self, encoder_name, decoder_name): + device = comfy.model_management.get_torch_device() + encoder_path = folder_paths.get_full_path("taesd", encoder_name) + decoder_path = folder_paths.get_full_path("taesd", decoder_name) + taesd = TAESD(encoder_path, decoder_path).to(device) + return (taesd,) + class ControlNetLoader: @classmethod def INPUT_TYPES(s): @@ -931,7 +985,37 @@ class SetLatentNoiseMask: s["noise_mask"] = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])) return (s,) -def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): + +def decode_latent_to_preview_image(taesd, device, preview_format, x0): + x_sample = taesd.decoder(x0.to(device))[0].detach() + x_sample = taesd.unscale_latents(x_sample) # returns value in [-2, 2] + x_sample = x_sample * 0.5 + + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255. 
* np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + + preview_image = Image.fromarray(x_sample) + + if preview_image.size[0] > MAX_PREVIEW_RESOLUTION or preview_image.size[1] > MAX_PREVIEW_RESOLUTION: + preview_image.thumbnail((MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS) + + preview_type = 1 + if preview_format == "JPEG": + preview_type = 1 + elif preview_format == "PNG": + preview_type = 2 + + bytesIO = BytesIO() + header = struct.pack(">I", preview_type) + bytesIO.write(header) + preview_image.save(bytesIO, format=preview_format) + preview_bytes = bytesIO.getvalue() + + return preview_bytes + + +def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, taesd=None): device = comfy.model_management.get_torch_device() latent_image = latent["samples"] @@ -945,9 +1029,16 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if "noise_mask" in latent: noise_mask = latent["noise_mask"] + preview_format = "JPEG" + if preview_format not in ["JPEG", "PNG"]: + preview_format = "JPEG" + pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): - pbar.update_absolute(step + 1, total_steps) + preview_bytes = None + if taesd: + preview_bytes = decode_latent_to_preview_image(taesd, device, preview_format, x0) + pbar.update_absolute(step + 1, total_steps, preview_bytes) samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, @@ -970,15 +1061,18 @@ class KSampler: "negative": ("CONDITIONING", ), "latent_image": ("LATENT", ), "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - }} + }, + "optional": { + "taesd": ("TAESD",) + }} RETURN_TYPES = ("LATENT",) FUNCTION = "sample" CATEGORY = "sampling" - def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): - return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) + def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, taesd=None): + return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, taesd=taesd) class KSamplerAdvanced: @classmethod @@ -997,21 +1091,24 @@ class KSamplerAdvanced: "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), "return_with_leftover_noise": (["disable", "enable"], ), - }} + }, + "optional": { + "taesd": ("TAESD",) + }} RETURN_TYPES = ("LATENT",) FUNCTION = "sample" CATEGORY = "sampling" - def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0): + def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, taesd=None): force_full_denoise = True if return_with_leftover_noise == "enable": force_full_denoise = False disable_noise = False if add_noise == "disable": disable_noise = True - return common_ksampler(model, noise_seed, steps, cfg, sampler_name, 
scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) + return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, taesd=taesd) class SaveImage: def __init__(self): @@ -1270,6 +1367,9 @@ NODE_CLASS_MAPPINGS = { "VAEEncode": VAEEncode, "VAEEncodeForInpaint": VAEEncodeForInpaint, "VAELoader": VAELoader, + "TAESDDecode": TAESDDecode, + "TAESDEncode": TAESDEncode, + "TAESDLoader": TAESDLoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "LatentUpscaleBy": LatentUpscaleBy, @@ -1324,6 +1424,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoader": "Load Checkpoint (With Config)", "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", + "TAESDLoader": "Load TAESD", "LoraLoader": "Load LoRA", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", @@ -1346,6 +1447,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SetLatentNoiseMask": "Set Latent Noise Mask", "VAEDecode": "VAE Decode", "VAEEncode": "VAE Encode", + "TAESDDecode": "TAESD Decode", + "TAESDEncode": "TAESD Encode", "LatentRotate": "Rotate Latent", "LatentFlip": "Flip Latent", "LatentCrop": "Crop Latent", diff --git a/server.py b/server.py index c0b4729de..174d38af1 100644 --- a/server.py +++ b/server.py @@ -7,6 +7,7 @@ import execution import uuid import json import glob +import struct from PIL import Image from io import BytesIO @@ -25,6 +26,11 @@ from comfy.cli_args import args import comfy.utils import comfy.model_management + +class BinaryEventTypes: + PREVIEW_IMAGE = 1 + + @web.middleware async def cache_control(request: web.Request, handler): response: web.Response = await handler(request) @@ -457,16 +463,37 @@ class PromptServer(): return prompt_info async def send(self, event, data, sid=None): - message = {"type": event, "data": data} - - if isinstance(message, str) == False: - message = json.dumps(message) + if isinstance(data, (bytes, bytearray)): + await self.send_bytes(event, data, sid) + else: + await self.send_json(event, data, sid) + + def encode_bytes(self, event, data): + if not isinstance(event, int): + raise RuntimeError(f"Binary event types must be integers, got {event}") + + packed = struct.pack(">I", event) + message = bytearray(packed) + message.extend(data) + return message + + async def send_bytes(self, event, data, sid=None): + message = self.encode_bytes(event, data) if sid is None: for ws in self.sockets.values(): - await ws.send_str(message) + await ws.send_bytes(message) elif sid in self.sockets: - await self.sockets[sid].send_str(message) + await self.sockets[sid].send_bytes(message) + + async def send_json(self, event, data, sid=None): + message = {"type": event, "data": data} + + if sid is None: + for ws in self.sockets.values(): + await ws.send_json(message) + elif sid in self.sockets: + await self.sockets[sid].send_json(message) def send_sync(self, event, data, sid=None): self.loop.call_soon_threadsafe( diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index bfcd847a3..84c2a3d10 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -21,6 +21,7 @@ const colorPalettes = { "MODEL": "#B39DDB", // light lavender-purple "STYLE_MODEL": "#C2FFAE", // light green-yellow "VAE": "#FF6E6E", 
// bright red + "TAESD": "#DCC274", // cheesecake }, "litegraph_base": { "NODE_TITLE_COLOR": "#999", diff --git a/web/scripts/api.js b/web/scripts/api.js index 378165b3a..780c74b30 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -42,6 +42,7 @@ class ComfyApi extends EventTarget { this.socket = new WebSocket( `ws${window.location.protocol === "https:" ? "s" : ""}://${location.host}/ws${existingSession}` ); + this.socket.binaryType = "arraybuffer"; this.socket.addEventListener("open", () => { opened = true; @@ -70,39 +71,66 @@ class ComfyApi extends EventTarget { this.socket.addEventListener("message", (event) => { try { - const msg = JSON.parse(event.data); - switch (msg.type) { - case "status": - if (msg.data.sid) { - this.clientId = msg.data.sid; - window.name = this.clientId; + if (event.data instanceof ArrayBuffer) { + const view = new DataView(event.data); + const eventType = view.getUint32(0); + const buffer = event.data.slice(4); + console.error("BINARY", eventType); + switch (eventType) { + case 1: + const view2 = new DataView(event.data); + const imageType = view2.getUint32(0) + let imageMime + switch (imageType) { + case 1: + default: + imageMime = "image/jpeg"; + break; + case 2: + imageMime = "image/png" } - this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); - break; - case "progress": - this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); - break; - case "executing": - this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); - break; - case "executed": - this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); - break; - case "execution_start": - this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data })); - break; - case "execution_error": - this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data })); + const jpegBlob = new Blob([buffer.slice(4)], { type: imageMime }); + this.dispatchEvent(new CustomEvent("b_preview", { detail: jpegBlob })); break; default: - if (this.#registered.has(msg.type)) { - this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); - } else { - throw new Error("Unknown message type"); - } + throw new Error(`Unknown binary websocket message of type ${eventType}`); + } + } + else { + const msg = JSON.parse(event.data); + switch (msg.type) { + case "status": + if (msg.data.sid) { + this.clientId = msg.data.sid; + window.name = this.clientId; + } + this.dispatchEvent(new CustomEvent("status", { detail: msg.data.status })); + break; + case "progress": + this.dispatchEvent(new CustomEvent("progress", { detail: msg.data })); + break; + case "executing": + this.dispatchEvent(new CustomEvent("executing", { detail: msg.data.node })); + break; + case "executed": + this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); + break; + case "execution_start": + this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data })); + break; + case "execution_error": + this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data })); + break; + default: + if (this.#registered.has(msg.type)) { + this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); + } else { + throw new Error(`Unknown message type ${msg.type}`); + } + } } } catch (error) { - console.warn("Unhandled message:", event.data); + console.warn("Unhandled message:", event.data, error); } }); } diff --git a/web/scripts/app.js b/web/scripts/app.js index 95447ffa0..495d43e1f 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ 
-44,6 +44,12 @@ export class ComfyApp { */ this.nodeOutputs = {}; + /** + * Stores the preview image data for each node + * @type {Record} + */ + this.nodePreviewImages = {}; + /** * If the shift key on the keyboard is pressed * @type {boolean} @@ -367,29 +373,52 @@ export class ComfyApp { node.prototype.onDrawBackground = function (ctx) { if (!this.flags.collapsed) { + let imgURLs = [] + let imagesChanged = false + const output = app.nodeOutputs[this.id + ""]; if (output && output.images) { if (this.images !== output.images) { this.images = output.images; - this.imgs = null; - this.imageIndex = null; + imagesChanged = true; + imgURLs = imgURLs.concat(output.images.map(params => { + return "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam(); + })) + } + } + + const preview = app.nodePreviewImages[this.id + ""] + if (this.preview !== preview) { + this.preview = preview + imagesChanged = true; + if (preview != null) { + imgURLs.push(preview); + } + } + + if (imagesChanged) { + this.imageIndex = null; + if (imgURLs.length > 0) { Promise.all( - output.images.map((src) => { + imgURLs.map((src) => { return new Promise((r) => { const img = new Image(); img.onload = () => r(img); img.onerror = () => r(null); - img.src = "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam(); + img.src = src }); }) ).then((imgs) => { - if (this.images === output.images) { + if ((!output || this.images === output.images) && (!preview || this.preview === preview)) { this.imgs = imgs.filter(Boolean); this.setSizeForImage?.(); app.graph.setDirtyCanvas(true); } }); } + else { + this.imgs = null; + } } if (this.imgs && this.imgs.length) { @@ -901,17 +930,20 @@ export class ComfyApp { this.progress = null; this.runningNodeId = detail; this.graph.setDirtyCanvas(true, false); + delete this.nodePreviewImages[this.runningNodeId] }); api.addEventListener("executed", ({ detail }) => { this.nodeOutputs[detail.node] = detail.output; const node = this.graph.getNodeById(detail.node); - if (node?.onExecuted) { - node.onExecuted(detail.output); + if (node) { + if (node.onExecuted) + node.onExecuted(detail.output); } }); api.addEventListener("execution_start", ({ detail }) => { + this.runningNodeId = null; this.lastExecutionError = null }); @@ -922,6 +954,16 @@ export class ComfyApp { this.canvas.draw(true, true); }); + api.addEventListener("b_preview", ({ detail }) => { + const id = this.runningNodeId + if (id == null) + return; + + const blob = detail + const blobUrl = URL.createObjectURL(blob) + this.nodePreviewImages[id] = [blobUrl] + }); + api.init(); } @@ -1465,8 +1507,10 @@ export class ComfyApp { */ clean() { this.nodeOutputs = {}; + this.nodePreviewImages = {} this.lastPromptError = null; this.lastExecutionError = null; + this.runningNodeId = null; } } From 1c40296d747475f0338cba5b00e4b49b62f37b97 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 21:15:06 -0500 Subject: [PATCH 55/82] Fix --- web/scripts/api.js | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/web/scripts/api.js b/web/scripts/api.js index 780c74b30..8313f1abe 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -75,7 +75,6 @@ class ComfyApi extends EventTarget { const view = new DataView(event.data); const eventType = view.getUint32(0); const buffer = event.data.slice(4); - console.error("BINARY", eventType); switch (eventType) { case 1: const view2 = new DataView(event.data); @@ -89,8 +88,8 @@ class ComfyApi extends 
EventTarget { case 2: imageMime = "image/png" } - const jpegBlob = new Blob([buffer.slice(4)], { type: imageMime }); - this.dispatchEvent(new CustomEvent("b_preview", { detail: jpegBlob })); + const imageBlob = new Blob([buffer.slice(4)], { type: imageMime }); + this.dispatchEvent(new CustomEvent("b_preview", { detail: imageBlob })); break; default: throw new Error(`Unknown binary websocket message of type ${eventType}`); From 38bc02bb4085434f004775a3959a33cbfbc8ee80 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 21:21:59 -0500 Subject: [PATCH 56/82] Fix --- main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 908ff7af7..9f1c89282 100644 --- a/main.py +++ b/main.py @@ -41,10 +41,10 @@ async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) def hijack_progress(server): - def hook(value, total, preview_bytes_jpeg): + def hook(value, total, preview_image_bytes): server.send_sync("progress", { "value": value, "max": total}, server.client_id) - if preview_bytes_jpeg is not None: - server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes_jpeg, server.client_id) + if preview_image_bytes is not None: + server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id) pass comfy.utils.set_progress_bar_global_hook(hook) From a9fa2d3727cf6e8d590e9897a9652b977f5c84ab Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 22:25:39 -0500 Subject: [PATCH 57/82] Fix --- folder_paths.py | 1 + 1 file changed, 1 insertion(+) diff --git a/folder_paths.py b/folder_paths.py index a1bf1444d..387299284 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -18,6 +18,7 @@ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision" folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) +folder_names_and_paths["taesd"] = ([os.path.join(models_dir, "taesd")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], supported_pt_extensions) From f326a0a4680526db4681c9eb83d817f8b5272373 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 23:43:22 -0500 Subject: [PATCH 58/82] Make new LATENT_PREVIEWER type for declaring KSampler preview methods --- nodes.py | 52 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 13 deletions(-) diff --git a/nodes.py b/nodes.py index a80f81933..ec9d99845 100644 --- a/nodes.py +++ b/nodes.py @@ -34,6 +34,11 @@ import importlib import folder_paths +class LatentPreviewer: + def decode_latent_to_preview(self, device, x0): + pass + + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -282,6 +287,27 @@ class TAESDEncode: samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device) return ({"samples": samples}, ) +class TAESDPreviewerImpl(LatentPreviewer): + def 
__init__(self, taesd): + self.taesd = taesd + + def decode_latent_to_preview(self, device, x0): + x_sample = self.taesd.decoder(x0.to(device))[0].detach() + x_sample = self.taesd.unscale_latents(x_sample) # returns value in [-2, 2] + x_sample = x_sample * 0.5 + return x_sample + +class TAESDPreviewer: + @classmethod + def INPUT_TYPES(s): + return {"required": { "taesd": ("TAESD", ), }} + RETURN_TYPES = ("LATENT_PREVIEWER",) + FUNCTION = "make_previewer" + + CATEGORY = "latent/previewer" + + def make_previewer(self, taesd): + return (TAESDPreviewerImpl(taesd), ) class SaveLatent: def __init__(self): @@ -986,10 +1012,8 @@ class SetLatentNoiseMask: return (s,) -def decode_latent_to_preview_image(taesd, device, preview_format, x0): - x_sample = taesd.decoder(x0.to(device))[0].detach() - x_sample = taesd.unscale_latents(x_sample) # returns value in [-2, 2] - x_sample = x_sample * 0.5 +def decode_latent_to_preview_image(previewer, device, preview_format, x0): + x_sample = previewer.decode_latent_to_preview(device, x0) x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) @@ -1015,7 +1039,7 @@ def decode_latent_to_preview_image(taesd, device, preview_format, x0): return preview_bytes -def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, taesd=None): +def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, previewer=None): device = comfy.model_management.get_torch_device() latent_image = latent["samples"] @@ -1036,8 +1060,8 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): preview_bytes = None - if taesd: - preview_bytes = decode_latent_to_preview_image(taesd, device, preview_format, x0) + if previewer: + preview_bytes = decode_latent_to_preview_image(previewer, device, preview_format, x0) pbar.update_absolute(step + 1, total_steps, preview_bytes) samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, @@ -1063,7 +1087,7 @@ class KSampler: "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }, "optional": { - "taesd": ("TAESD",) + "previewer": ("LATENT_PREVIEWER",) }} RETURN_TYPES = ("LATENT",) @@ -1071,8 +1095,8 @@ class KSampler: CATEGORY = "sampling" - def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, taesd=None): - return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, taesd=taesd) + def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, previewer=None): + return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, previewer=previewer) class KSamplerAdvanced: @classmethod @@ -1093,7 +1117,7 @@ class KSamplerAdvanced: "return_with_leftover_noise": (["disable", "enable"], ), }, "optional": { - "taesd": ("TAESD",) + "previewer": ("LATENT_PREVIEWER",) }} RETURN_TYPES = ("LATENT",) @@ -1101,14 +1125,14 @@ class KSamplerAdvanced: CATEGORY = "sampling" - def sample(self, model, add_noise, noise_seed, steps, cfg, 
sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, taesd=None): + def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, previewer=None): force_full_denoise = True if return_with_leftover_noise == "enable": force_full_denoise = False disable_noise = False if add_noise == "disable": disable_noise = True - return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, taesd=taesd) + return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, previewer=previewer) class SaveImage: def __init__(self): @@ -1369,6 +1393,7 @@ NODE_CLASS_MAPPINGS = { "VAELoader": VAELoader, "TAESDDecode": TAESDDecode, "TAESDEncode": TAESDEncode, + "TAESDPreviewer": TAESDPreviewer, "TAESDLoader": TAESDLoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, @@ -1425,6 +1450,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", "TAESDLoader": "Load TAESD", + "TAESDPreviewer": "TAESD Previewer", "LoraLoader": "Load LoRA", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", From 48f7ec750c6c432ac3156fa4aeadd801242ed1e8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Mon, 5 Jun 2023 13:19:02 -0500 Subject: [PATCH 59/82] Make previews into cli option --- comfy/cli_args.py | 36 ++++++++++++++++ main.py | 1 - nodes.py | 104 +++++++++------------------------------------ web/scripts/app.js | 2 +- 4 files changed, 58 insertions(+), 85 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index cc4709f70..fdb2a34df 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -1,4 +1,35 @@ import argparse +import enum + + +class EnumAction(argparse.Action): + """ + Argparse action for handling Enums + """ + def __init__(self, **kwargs): + # Pop off the type value + enum_type = kwargs.pop("type", None) + + # Ensure an Enum subclass is provided + if enum_type is None: + raise ValueError("type must be assigned an Enum when using EnumAction") + if not issubclass(enum_type, enum.Enum): + raise TypeError("type must be an Enum when using EnumAction") + + # Generate choices from the Enum + choices = tuple(e.value for e in enum_type) + kwargs.setdefault("choices", choices) + kwargs.setdefault("metavar", f"[{','.join(list(choices))}]") + + super(EnumAction, self).__init__(**kwargs) + + self._enum = enum_type + + def __call__(self, parser, namespace, values, option_string=None): + # Convert value back into an Enum + value = self._enum(values) + setattr(namespace, self.dest, value) + parser = argparse.ArgumentParser() @@ -13,6 +44,11 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") +class PreviewType(enum.Enum): + TAESD = "taesd" 
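As an aside, a hedged sketch of how the `EnumAction` defined above behaves with argparse. The `Color` enum is purely illustrative, and the snippet assumes `EnumAction` from this file is in scope:

```
import argparse
import enum

class Color(enum.Enum):
    RED = "red"
    BLUE = "blue"

parser = argparse.ArgumentParser()
# EnumAction pops `type` off the kwargs, derives choices ("red", "blue") from the
# enum values, and converts the parsed string back into an enum member.
parser.add_argument("--color", type=Color, default=Color.RED, action=EnumAction)
args = parser.parse_args(["--color", "blue"])
assert args.color is Color.BLUE
```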
+parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.") +parser.add_argument("--default-preview-method", type=str, default=PreviewType.TAESD, metavar="PREVIEW_TYPE", help="Default preview method for sampler nodes.") + attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") diff --git a/main.py b/main.py index 9f1c89282..15f75f892 100644 --- a/main.py +++ b/main.py @@ -45,7 +45,6 @@ def hijack_progress(server): server.send_sync("progress", { "value": value, "max": total}, server.client_id) if preview_image_bytes is not None: server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id) - pass comfy.utils.set_progress_bar_global_hook(hook) def cleanup_temp(): diff --git a/nodes.py b/nodes.py index ec9d99845..760747828 100644 --- a/nodes.py +++ b/nodes.py @@ -24,6 +24,7 @@ import comfy.samplers import comfy.sample import comfy.sd import comfy.utils +from comfy.cli_args import args from comfy.taesd.taesd import TAESD import comfy.clip_vision @@ -180,21 +181,6 @@ class VAEDecodeTiled: def decode(self, vae, samples): return (vae.decode_tiled(samples["samples"]), ) -class TAESDDecode: - @classmethod - def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), "taesd": ("TAESD", )}} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "decode" - - CATEGORY = "latent" - - def decode(self, taesd, samples): - device = comfy.model_management.get_torch_device() - # [B, C, H, W] -> [B, H, W, C] - pixels = taesd.decoder(samples["samples"].to(device)).permute(0, 2, 3, 1).detach().clamp(0, 1) - return (pixels, ) - class VAEEncode: @classmethod def INPUT_TYPES(s): @@ -272,21 +258,6 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) -class TAESDEncode: - @classmethod - def INPUT_TYPES(s): - return {"required": { "pixels": ("IMAGE", ), "taesd": ("TAESD", )}} - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "latent" - - def encode(self, taesd, pixels): - device = comfy.model_management.get_torch_device() - # [B, H, W, C] -> [B, C, H, W] - samples = taesd.encoder(pixels.permute(0, 3, 1, 2).to(device)).to(device) - return ({"samples": samples}, ) - class TAESDPreviewerImpl(LatentPreviewer): def __init__(self, taesd): self.taesd = taesd @@ -297,18 +268,6 @@ class TAESDPreviewerImpl(LatentPreviewer): x_sample = x_sample * 0.5 return x_sample -class TAESDPreviewer: - @classmethod - def INPUT_TYPES(s): - return {"required": { "taesd": ("TAESD", ), }} - RETURN_TYPES = ("LATENT_PREVIEWER",) - FUNCTION = "make_previewer" - - CATEGORY = "latent/previewer" - - def make_previewer(self, taesd): - return (TAESDPreviewerImpl(taesd), ) - class SaveLatent: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -524,26 +483,6 @@ class VAELoader: vae = comfy.sd.VAE(ckpt_path=vae_path) return (vae,) -class TAESDLoader: - @classmethod - def INPUT_TYPES(s): - model_list = folder_paths.get_filename_list("taesd") - return {"required": { - "encoder_name": (model_list, { "default": "taesd_encoder.pth" }), - "decoder_name": (model_list, { "default": "taesd_decoder.pth" }) - }} - RETURN_TYPES = ("TAESD",) - FUNCTION = "load_taesd" - - CATEGORY = "loaders" - - 
def load_taesd(self, encoder_name, decoder_name): - device = comfy.model_management.get_torch_device() - encoder_path = folder_paths.get_full_path("taesd", encoder_name) - decoder_path = folder_paths.get_full_path("taesd", decoder_name) - taesd = TAESD(encoder_path, decoder_path).to(device) - return (taesd,) - class ControlNetLoader: @classmethod def INPUT_TYPES(s): @@ -1039,7 +978,7 @@ def decode_latent_to_preview_image(previewer, device, preview_format, x0): return preview_bytes -def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, previewer=None): +def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): device = comfy.model_management.get_torch_device() latent_image = latent["samples"] @@ -1057,6 +996,17 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if preview_format not in ["JPEG", "PNG"]: preview_format = "JPEG" + previewer = None + if not args.disable_previews: + # TODO previewer methods + encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth") + decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth") + if encoder_path and decoder_path: + taesd = TAESD(encoder_path, decoder_path).to(device) + previewer = TAESDPreviewerImpl(taesd) + else: + print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth") + pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): preview_bytes = None @@ -1085,18 +1035,16 @@ class KSampler: "negative": ("CONDITIONING", ), "latent_image": ("LATENT", ), "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - }, - "optional": { - "previewer": ("LATENT_PREVIEWER",) - }} + } + } RETURN_TYPES = ("LATENT",) FUNCTION = "sample" CATEGORY = "sampling" - def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, previewer=None): - return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, previewer=previewer) + def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): + return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) class KSamplerAdvanced: @classmethod @@ -1115,24 +1063,22 @@ class KSamplerAdvanced: "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), "return_with_leftover_noise": (["disable", "enable"], ), - }, - "optional": { - "previewer": ("LATENT_PREVIEWER",) - }} + } + } RETURN_TYPES = ("LATENT",) FUNCTION = "sample" CATEGORY = "sampling" - def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, previewer=None): + def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0): force_full_denoise = True if return_with_leftover_noise == "enable": force_full_denoise = False disable_noise = False if add_noise == "disable": disable_noise = 
True - return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise, previewer=previewer) + return common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise) class SaveImage: def __init__(self): @@ -1391,10 +1337,6 @@ NODE_CLASS_MAPPINGS = { "VAEEncode": VAEEncode, "VAEEncodeForInpaint": VAEEncodeForInpaint, "VAELoader": VAELoader, - "TAESDDecode": TAESDDecode, - "TAESDEncode": TAESDEncode, - "TAESDPreviewer": TAESDPreviewer, - "TAESDLoader": TAESDLoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, "LatentUpscaleBy": LatentUpscaleBy, @@ -1449,8 +1391,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoader": "Load Checkpoint (With Config)", "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", - "TAESDLoader": "Load TAESD", - "TAESDPreviewer": "TAESD Previewer", "LoraLoader": "Load LoRA", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", @@ -1473,8 +1413,6 @@ NODE_DISPLAY_NAME_MAPPINGS = { "SetLatentNoiseMask": "Set Latent Noise Mask", "VAEDecode": "VAE Decode", "VAEEncode": "VAE Encode", - "TAESDDecode": "TAESD Decode", - "TAESDEncode": "TAESD Encode", "LatentRotate": "Rotate Latent", "LatentFlip": "Flip Latent", "LatentCrop": "Crop Latent", diff --git a/web/scripts/app.js b/web/scripts/app.js index 495d43e1f..9df94c9eb 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -382,7 +382,7 @@ export class ComfyApp { this.images = output.images; imagesChanged = true; imgURLs = imgURLs.concat(output.images.map(params => { - return "/view?" + new URLSearchParams(src).toString() + app.getPreviewFormatParam(); + return "/view?" 
+ new URLSearchParams(params).toString() + app.getPreviewFormatParam(); })) } } From 70d72c4336040de7fb47a61fa048af3e9fe632b5 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Mon, 5 Jun 2023 15:26:56 -0500 Subject: [PATCH 60/82] Slightly less vibrant sample --- nodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nodes.py b/nodes.py index 760747828..d11b4ae69 100644 --- a/nodes.py +++ b/nodes.py @@ -264,8 +264,8 @@ class TAESDPreviewerImpl(LatentPreviewer): def decode_latent_to_preview(self, device, x0): x_sample = self.taesd.decoder(x0.to(device))[0].detach() - x_sample = self.taesd.unscale_latents(x_sample) # returns value in [-2, 2] - x_sample = x_sample * 0.5 + # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] + x_sample = x_sample.sub(0.5).mul(2) return x_sample class SaveLatent: From d5a28fadaa000c300cf9490c1f92ba5275871a30 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Mon, 5 Jun 2023 18:39:56 -0500 Subject: [PATCH 61/82] Add latent2rgb preview --- comfy/cli_args.py | 5 ++-- comfy/utils.py | 3 +++ nodes.py | 61 ++++++++++++++++++++++++++++++++--------------- 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index fdb2a34df..fae666127 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -44,10 +44,11 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") -class PreviewType(enum.Enum): +class LatentPreviewType(enum.Enum): + Latent2RGB = "latent2rgb" TAESD = "taesd" parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.") -parser.add_argument("--default-preview-method", type=str, default=PreviewType.TAESD, metavar="PREVIEW_TYPE", help="Default preview method for sampler nodes.") +parser.add_argument("--default-preview-method", type=str, default=LatentPreviewType.Latent2RGB, metavar="PREVIEW_TYPE", help="Default preview method for sampler nodes.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. 
Ignored when xformers is used.") diff --git a/comfy/utils.py b/comfy/utils.py index 291c62e42..08944ade3 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,6 +1,7 @@ import torch import math import struct +import comfy.model_management def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -166,6 +167,8 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") for y in range(0, s.shape[2], tile_y - overlap): for x in range(0, s.shape[3], tile_x - overlap): + comfy.model_management.throw_exception_if_processing_interrupted() + s_in = s[:,:,y:y+tile_y,x:x+tile_x] ps = function(s_in).cpu() diff --git a/nodes.py b/nodes.py index d11b4ae69..74c664bdc 100644 --- a/nodes.py +++ b/nodes.py @@ -24,7 +24,7 @@ import comfy.samplers import comfy.sample import comfy.sd import comfy.utils -from comfy.cli_args import args +from comfy.cli_args import args, LatentPreviewType from comfy.taesd.taesd import TAESD import comfy.clip_vision @@ -40,6 +40,27 @@ class LatentPreviewer: pass +class Latent2RGBPreviewer(LatentPreviewer): + def __init__(self): + self.latent_rgb_factors = torch.tensor([ + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ], device="cpu") + + def decode_latent_to_preview(self, device, x0): + latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors + + latents_ubyte = (((latent_image + 1) / 2) + .clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + .byte()).cpu() + + return Image.fromarray(latents_ubyte.numpy()) + + def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -266,7 +287,13 @@ class TAESDPreviewerImpl(LatentPreviewer): x_sample = self.taesd.decoder(x0.to(device))[0].detach() # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] x_sample = x_sample.sub(0.5).mul(2) - return x_sample + + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + + preview_image = Image.fromarray(x_sample) + return preview_image class SaveLatent: def __init__(self): @@ -952,16 +979,8 @@ class SetLatentNoiseMask: def decode_latent_to_preview_image(previewer, device, preview_format, x0): - x_sample = previewer.decode_latent_to_preview(device, x0) - - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255. 
* np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - - preview_image = Image.fromarray(x_sample) - - if preview_image.size[0] > MAX_PREVIEW_RESOLUTION or preview_image.size[1] > MAX_PREVIEW_RESOLUTION: - preview_image.thumbnail((MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS) + preview_image = previewer.decode_latent_to_preview(device, x0) + preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS) preview_type = 1 if preview_format == "JPEG": @@ -999,13 +1018,17 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, previewer = None if not args.disable_previews: # TODO previewer methods - encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth") - decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth") - if encoder_path and decoder_path: - taesd = TAESD(encoder_path, decoder_path).to(device) - previewer = TAESDPreviewerImpl(taesd) - else: - print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth") + if args.default_preview_method == LatentPreviewType.TAESD: + encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth") + decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth") + if encoder_path and decoder_path: + taesd = TAESD(encoder_path, decoder_path).to(device) + previewer = TAESDPreviewerImpl(taesd) + else: + print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth") + + if previewer is None: + previewer = Latent2RGBPreviewer() pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): From 3e17971acbbdd45593403b72f23fb66d703d1abb Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Mon, 5 Jun 2023 18:59:10 -0500 Subject: [PATCH 62/82] preview method autodetection --- comfy/cli_args.py | 5 +++-- ...esd_encoder_pth_and_taesd_decoder_pth_here | 0 nodes.py | 20 +++++++++++++------ 3 files changed, 17 insertions(+), 8 deletions(-) create mode 100644 models/taesd/put_taesd_encoder_pth_and_taesd_decoder_pth_here diff --git a/comfy/cli_args.py b/comfy/cli_args.py index fae666127..3e6b1daa6 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -44,11 +44,12 @@ parser.add_argument("--dont-upcast-attention", action="store_true", help="Disabl parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") -class LatentPreviewType(enum.Enum): +class LatentPreviewMethod(enum.Enum): + Auto = "auto" Latent2RGB = "latent2rgb" TAESD = "taesd" parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.") -parser.add_argument("--default-preview-method", type=str, default=LatentPreviewType.Latent2RGB, metavar="PREVIEW_TYPE", help="Default preview method for sampler nodes.") +parser.add_argument("--default-preview-method", type=str, default=LatentPreviewMethod.Auto, metavar="PREVIEW_METHOD", help="Default preview method for sampler nodes.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. 
Ignored when xformers is used.") diff --git a/models/taesd/put_taesd_encoder_pth_and_taesd_decoder_pth_here b/models/taesd/put_taesd_encoder_pth_and_taesd_decoder_pth_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 74c664bdc..6266b6c0d 100644 --- a/nodes.py +++ b/nodes.py @@ -24,7 +24,7 @@ import comfy.samplers import comfy.sample import comfy.sd import comfy.utils -from comfy.cli_args import args, LatentPreviewType +from comfy.cli_args import args, LatentPreviewMethod from comfy.taesd.taesd import TAESD import comfy.clip_vision @@ -1018,11 +1018,19 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, previewer = None if not args.disable_previews: # TODO previewer methods - if args.default_preview_method == LatentPreviewType.TAESD: - encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth") - decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth") - if encoder_path and decoder_path: - taesd = TAESD(encoder_path, decoder_path).to(device) + taesd_encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth") + taesd_decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth") + + method = args.default_preview_method + + if args.default_preview_method == LatentPreviewMethod.AUTO: + method = LatentPreviewMethod.Latent2RGB + if taesd_encoder_path and taesd_encoder_path: + method = LatentPreviewMethod.TAESD + + if method == LatentPreviewMethod.TAESD: + if taesd_encoder_path and taesd_encoder_path: + taesd = TAESD(taesd_encoder_path, taesd_decoder_path).to(device) previewer = TAESDPreviewerImpl(taesd) else: print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth") From 8b4a6c19c2f972bf2701921f16cc6f9729659fe9 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Mon, 5 Jun 2023 19:00:51 -0500 Subject: [PATCH 63/82] Fix --- nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nodes.py b/nodes.py index 6266b6c0d..971b5c3b8 100644 --- a/nodes.py +++ b/nodes.py @@ -1023,7 +1023,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, method = args.default_preview_method - if args.default_preview_method == LatentPreviewMethod.AUTO: + if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB if taesd_encoder_path and taesd_encoder_path: method = LatentPreviewMethod.TAESD From 2b2ea5194e04a60130ea0d41778bd915bb157b40 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Mon, 5 Jun 2023 19:16:51 -0500 Subject: [PATCH 64/82] Add readme note --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index bfa8904df..6e0803ab7 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin - [Upscale Models (ESRGAN, ESRGAN variants, SwinIR, Swin2SR, etc...)](https://comfyanonymous.github.io/ComfyUI_examples/upscale_models/) - [unCLIP Models](https://comfyanonymous.github.io/ComfyUI_examples/unclip/) - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/) +- Latent previews with [TAESD](https://github.com/madebyollin/taesd) - Starts up very fast. - Works fully offline: will never download anything. - [Config file](extra_model_paths.yaml.example) to set the search paths for models. 
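For orientation, the previews introduced in patch 54 reach the browser as binary websocket frames: a big-endian uint32 event type (1 = PREVIEW_IMAGE) followed by the payload, whose first four bytes are another big-endian uint32 giving the image type (1 = JPEG, 2 = PNG) and whose remainder is the encoded image. A hedged Python sketch of parsing such a frame, not part of the patch (the function name is invented):

```
import struct
from io import BytesIO
from PIL import Image

def parse_preview_frame(frame: bytes):
    """Split a PREVIEW_IMAGE websocket frame into a MIME type and a PIL image."""
    event_type, image_type = struct.unpack(">II", frame[:8])
    if event_type != 1:  # BinaryEventTypes.PREVIEW_IMAGE
        raise ValueError(f"unexpected binary event type {event_type}")
    mime = "image/png" if image_type == 2 else "image/jpeg"
    return mime, Image.open(BytesIO(frame[8:]))
```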
@@ -181,6 +182,10 @@ You can set this command line setting to disable the upcasting to fp32 in some c ```--dont-upcast-attention``` +## How to show high-quality previews? + +The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/taesd` folder. Once they're installed, restart ComfyUI to enable high-quality previews. + ## Support and dev channel [Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source). From a3a713b6c581f4c0487c58c5a20eca2a5e8e6bde Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 6 Jun 2023 01:26:52 -0400 Subject: [PATCH 65/82] Refactor previews into one command line argument. Clean up a few things. --- README.md | 4 +- comfy/cli_args.py | 5 +- comfy/taesd/taesd.py | 4 +- comfy/utils.py | 3 - folder_paths.py | 2 +- latent_preview.py | 95 +++++++++++++++++++ ...esd_encoder_pth_and_taesd_decoder_pth_here | 0 nodes.py | 94 +----------------- 8 files changed, 107 insertions(+), 100 deletions(-) create mode 100644 latent_preview.py rename models/{taesd => vae_approx}/put_taesd_encoder_pth_and_taesd_decoder_pth_here (100%) diff --git a/README.md b/README.md index 6e0803ab7..d998afe65 100644 --- a/README.md +++ b/README.md @@ -184,7 +184,9 @@ You can set this command line setting to disable the upcasting to fp32 in some c ## How to show high-quality previews? -The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/taesd` folder. Once they're installed, restart ComfyUI to enable high-quality previews. +Use ```--preview-method auto``` to enable previews. + +The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_encoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_encoder.pth) and [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews. 
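The "fast latent preview method" mentioned above is the Latent2RGBPreviewer from patch 61: a fixed 4x3 matrix maps the four latent channels straight to RGB without running a VAE. A minimal sketch of that projection, using a random latent as a stand-in for a sampler's x0:

```
import torch
from PIL import Image

# The same 4x3 projection used by Latent2RGBPreviewer; it yields a rough,
# eighth-resolution RGB impression of the image being sampled.
latent_rgb_factors = torch.tensor([
    [0.298, 0.207, 0.208],
    [0.187, 0.286, 0.173],
    [-0.158, 0.189, 0.264],
    [-0.184, -0.271, -0.473],
])

x0 = torch.randn(1, 4, 64, 64)                     # stand-in for a real latent
rgb = x0[0].permute(1, 2, 0) @ latent_rgb_factors  # [64, 64, 3]
preview = Image.fromarray(((rgb + 1) / 2).clamp(0, 1).mul(255).byte().numpy())
```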
## Support and dev channel diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 3e6b1daa6..b56497de0 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -45,11 +45,12 @@ parser.add_argument("--force-fp32", action="store_true", help="Force fp32 (If th parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") class LatentPreviewMethod(enum.Enum): + NoPreviews = "none" Auto = "auto" Latent2RGB = "latent2rgb" TAESD = "taesd" -parser.add_argument("--disable-previews", action="store_true", help="Disable showing node previews.") -parser.add_argument("--default-preview-method", type=str, default=LatentPreviewMethod.Auto, metavar="PREVIEW_METHOD", help="Default preview method for sampler nodes.") + +parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization instead of the sub-quadratic one. Ignored when xformers is used.") diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py index e64067454..1549345ae 100644 --- a/comfy/taesd/taesd.py +++ b/comfy/taesd/taesd.py @@ -50,9 +50,9 @@ class TAESD(nn.Module): self.encoder = Encoder() self.decoder = Decoder() if encoder_path is not None: - self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu")) + self.encoder.load_state_dict(torch.load(encoder_path, map_location="cpu", weights_only=True)) if decoder_path is not None: - self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu")) + self.decoder.load_state_dict(torch.load(decoder_path, map_location="cpu", weights_only=True)) @staticmethod def scale_latents(x): diff --git a/comfy/utils.py b/comfy/utils.py index 08944ade3..291c62e42 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,7 +1,6 @@ import torch import math import struct -import comfy.model_management def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -167,8 +166,6 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am out_div = torch.zeros((s.shape[0], out_channels, round(s.shape[2] * upscale_amount), round(s.shape[3] * upscale_amount)), device="cpu") for y in range(0, s.shape[2], tile_y - overlap): for x in range(0, s.shape[3], tile_x - overlap): - comfy.model_management.throw_exception_if_processing_interrupted() - s_in = s[:,:,y:y+tile_y,x:x+tile_x] ps = function(s_in).cpu() diff --git a/folder_paths.py b/folder_paths.py index 387299284..2ad1b1719 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -18,7 +18,7 @@ folder_names_and_paths["clip_vision"] = ([os.path.join(models_dir, "clip_vision" folder_names_and_paths["style_models"] = ([os.path.join(models_dir, "style_models")], supported_pt_extensions) folder_names_and_paths["embeddings"] = ([os.path.join(models_dir, "embeddings")], supported_pt_extensions) folder_names_and_paths["diffusers"] = ([os.path.join(models_dir, "diffusers")], ["folder"]) -folder_names_and_paths["taesd"] = ([os.path.join(models_dir, "taesd")], supported_pt_extensions) +folder_names_and_paths["vae_approx"] = ([os.path.join(models_dir, "vae_approx")], supported_pt_extensions) folder_names_and_paths["controlnet"] = ([os.path.join(models_dir, "controlnet"), os.path.join(models_dir, "t2i_adapter")], supported_pt_extensions) folder_names_and_paths["gligen"] = 
([os.path.join(models_dir, "gligen")], supported_pt_extensions) diff --git a/latent_preview.py b/latent_preview.py new file mode 100644 index 000000000..ef6c201b6 --- /dev/null +++ b/latent_preview.py @@ -0,0 +1,95 @@ +import torch +from PIL import Image, ImageOps +from io import BytesIO +import struct +import numpy as np + +from comfy.cli_args import args, LatentPreviewMethod +from comfy.taesd.taesd import TAESD +import folder_paths + +MAX_PREVIEW_RESOLUTION = 512 + +class LatentPreviewer: + def decode_latent_to_preview(self, x0): + pass + + def decode_latent_to_preview_image(self, preview_format, x0): + preview_image = self.decode_latent_to_preview(x0) + preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS) + + preview_type = 1 + if preview_format == "JPEG": + preview_type = 1 + elif preview_format == "PNG": + preview_type = 2 + + bytesIO = BytesIO() + header = struct.pack(">I", preview_type) + bytesIO.write(header) + preview_image.save(bytesIO, format=preview_format, quality=95) + preview_bytes = bytesIO.getvalue() + return preview_bytes + +class TAESDPreviewerImpl(LatentPreviewer): + def __init__(self, taesd): + self.taesd = taesd + + def decode_latent_to_preview(self, x0): + x_sample = self.taesd.decoder(x0)[0].detach() + # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] + x_sample = x_sample.sub(0.5).mul(2) + + x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + x_sample = x_sample.astype(np.uint8) + + preview_image = Image.fromarray(x_sample) + return preview_image + + +class Latent2RGBPreviewer(LatentPreviewer): + def __init__(self): + self.latent_rgb_factors = torch.tensor([ + # R G B + [0.298, 0.207, 0.208], # L1 + [0.187, 0.286, 0.173], # L2 + [-0.158, 0.189, 0.264], # L3 + [-0.184, -0.271, -0.473], # L4 + ], device="cpu") + + def decode_latent_to_preview(self, x0): + latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors + + latents_ubyte = (((latent_image + 1) / 2) + .clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + .byte()).cpu() + + return Image.fromarray(latents_ubyte.numpy()) + + +def get_previewer(device): + previewer = None + method = args.preview_method + if method != LatentPreviewMethod.NoPreviews: + # TODO previewer methods + taesd_decoder_path = folder_paths.get_full_path("vae_approx", "taesd_decoder.pth") + + if method == LatentPreviewMethod.Auto: + method = LatentPreviewMethod.Latent2RGB + if taesd_decoder_path: + method = LatentPreviewMethod.TAESD + + if method == LatentPreviewMethod.TAESD: + if taesd_decoder_path: + taesd = TAESD(None, taesd_decoder_path).to(device) + previewer = TAESDPreviewerImpl(taesd) + else: + print("Warning: TAESD previews enabled, but could not find models/vae_approx/taesd_decoder.pth") + + if previewer is None: + previewer = Latent2RGBPreviewer() + return previewer + + diff --git a/models/taesd/put_taesd_encoder_pth_and_taesd_decoder_pth_here b/models/vae_approx/put_taesd_encoder_pth_and_taesd_decoder_pth_here similarity index 100% rename from models/taesd/put_taesd_encoder_pth_and_taesd_decoder_pth_here rename to models/vae_approx/put_taesd_encoder_pth_and_taesd_decoder_pth_here diff --git a/nodes.py b/nodes.py index 971b5c3b8..b057504ed 100644 --- a/nodes.py +++ b/nodes.py @@ -7,15 +7,12 @@ import hashlib import traceback import math import time -import struct -from io import BytesIO from PIL import Image, ImageOps from 
PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch - sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -24,8 +21,6 @@ import comfy.samplers import comfy.sample import comfy.sd import comfy.utils -from comfy.cli_args import args, LatentPreviewMethod -from comfy.taesd.taesd import TAESD import comfy.clip_vision @@ -33,33 +28,7 @@ import comfy.model_management import importlib import folder_paths - - -class LatentPreviewer: - def decode_latent_to_preview(self, device, x0): - pass - - -class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self): - self.latent_rgb_factors = torch.tensor([ - # R G B - [0.298, 0.207, 0.208], # L1 - [0.187, 0.286, 0.173], # L2 - [-0.158, 0.189, 0.264], # L3 - [-0.184, -0.271, -0.473], # L4 - ], device="cpu") - - def decode_latent_to_preview(self, device, x0): - latent_image = x0[0].permute(1, 2, 0).cpu() @ self.latent_rgb_factors - - latents_ubyte = (((latent_image + 1) / 2) - .clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - .byte()).cpu() - - return Image.fromarray(latents_ubyte.numpy()) - +import latent_preview def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() @@ -68,7 +37,6 @@ def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) MAX_RESOLUTION=8192 -MAX_PREVIEW_RESOLUTION = 512 class CLIPTextEncode: @classmethod @@ -279,22 +247,6 @@ class VAEEncodeForInpaint: return ({"samples":t, "noise_mask": (mask_erosion[:,:,:x,:y].round())}, ) -class TAESDPreviewerImpl(LatentPreviewer): - def __init__(self, taesd): - self.taesd = taesd - - def decode_latent_to_preview(self, device, x0): - x_sample = self.taesd.decoder(x0.to(device))[0].detach() - # x_sample = self.taesd.unscale_latents(x_sample).div(4).add(0.5) # returns value in [-2, 2] - x_sample = x_sample.sub(0.5).mul(2) - - x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255. 
* np.moveaxis(x_sample.cpu().numpy(), 0, 2) - x_sample = x_sample.astype(np.uint8) - - preview_image = Image.fromarray(x_sample) - return preview_image - class SaveLatent: def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -978,25 +930,6 @@ class SetLatentNoiseMask: return (s,) -def decode_latent_to_preview_image(previewer, device, preview_format, x0): - preview_image = previewer.decode_latent_to_preview(device, x0) - preview_image = ImageOps.contain(preview_image, (MAX_PREVIEW_RESOLUTION, MAX_PREVIEW_RESOLUTION), Image.ANTIALIAS) - - preview_type = 1 - if preview_format == "JPEG": - preview_type = 1 - elif preview_format == "PNG": - preview_type = 2 - - bytesIO = BytesIO() - header = struct.pack(">I", preview_type) - bytesIO.write(header) - preview_image.save(bytesIO, format=preview_format) - preview_bytes = bytesIO.getvalue() - - return preview_bytes - - def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): device = comfy.model_management.get_torch_device() latent_image = latent["samples"] @@ -1015,34 +948,13 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, if preview_format not in ["JPEG", "PNG"]: preview_format = "JPEG" - previewer = None - if not args.disable_previews: - # TODO previewer methods - taesd_encoder_path = folder_paths.get_full_path("taesd", "taesd_encoder.pth") - taesd_decoder_path = folder_paths.get_full_path("taesd", "taesd_decoder.pth") - - method = args.default_preview_method - - if method == LatentPreviewMethod.Auto: - method = LatentPreviewMethod.Latent2RGB - if taesd_encoder_path and taesd_encoder_path: - method = LatentPreviewMethod.TAESD - - if method == LatentPreviewMethod.TAESD: - if taesd_encoder_path and taesd_encoder_path: - taesd = TAESD(taesd_encoder_path, taesd_decoder_path).to(device) - previewer = TAESDPreviewerImpl(taesd) - else: - print("Warning: TAESD previews enabled, but could not find models/taesd/taesd_encoder.pth and models/taesd/taesd_decoder.pth") - - if previewer is None: - previewer = Latent2RGBPreviewer() + previewer = latent_preview.get_previewer(device) pbar = comfy.utils.ProgressBar(steps) def callback(step, x0, x, total_steps): preview_bytes = None if previewer: - preview_bytes = decode_latent_to_preview_image(previewer, device, preview_format, x0) + preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) pbar.update_absolute(step + 1, total_steps, preview_bytes) samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, From 422163c2ba65b18d8208d4661d27c7e312e8d862 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Tue, 6 Jun 2023 22:27:44 +0900 Subject: [PATCH 66/82] bugfix: Fixing the calculation issue when an image widget is added to the size calculation of the text widget. 
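Patch 65 above concentrates preview handling in the new `latent_preview` module and wires it into `common_ksampler` through a per-step callback. A condensed usage sketch of that flow, with names taken from the diff and the surrounding sampler plumbing simplified:

```python
import comfy.utils
import latent_preview

def make_preview_callback(device, steps, preview_format="JPEG"):
    # get_previewer() returns None when --preview-method is "none" and falls back
    # to the cheap latent2rgb previewer when the TAESD decoder is missing.
    previewer = latent_preview.get_previewer(device)
    pbar = comfy.utils.ProgressBar(steps)

    def callback(step, x0, x, total_steps):
        preview_bytes = None
        if previewer:
            preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
        pbar.update_absolute(step + 1, total_steps, preview_bytes)

    return callback
```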
--- web/scripts/app.js | 4 ++++ web/scripts/widgets.js | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 9df94c9eb..27c67fb49 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -365,6 +365,10 @@ export class ComfyApp { } node.prototype.setSizeForImage = function () { + if (this.inputHeight) { + this.setSize(this.size); + return; + } const minHeight = getImageTop(this) + 220; if (this.size[1] < minHeight) { this.setSize([this.size[0], minHeight]); diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index d6faaddbf..dfa26aef4 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -115,12 +115,12 @@ function addMultilineWidget(node, name, opts, app) { // See how large each text input can be freeSpace -= widgetHeight; - freeSpace /= multi.length; + freeSpace /= multi.length + (!!node.imgs?.length); if (freeSpace < MIN_SIZE) { // There isnt enough space for all the widgets, increase the size of the node freeSpace = MIN_SIZE; - node.size[1] = y + widgetHeight + freeSpace * multi.length; + node.size[1] = y + widgetHeight + freeSpace * (multi.length + (!!node.imgs?.length)); node.graph.setDirtyCanvas(true); } From 3b5b095d04f73f85b19f7c44aaffe481081f3818 Mon Sep 17 00:00:00 2001 From: reaper47 Date: Tue, 6 Jun 2023 17:40:07 +0200 Subject: [PATCH 67/82] Add .idea/ to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index df6adbe4b..8380a2f7c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ custom_nodes/ !custom_nodes/example_node.py.example extra_model_paths.yaml /.vs +.idea/ \ No newline at end of file From 0e425603fb8ba12f1e7d09a1f58127347a94de98 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 6 Jun 2023 03:25:49 -0400 Subject: [PATCH 68/82] Small refactor. 
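The `comfy/sd.py` hunk in patch 68 below also changes how attention-head settings are picked for SD1.x versus SD2.x checkpoints, keying off the CLIP context width. The heuristic it encodes, pulled out for readability (values from the diff, function wrapper assumed):

```python
def head_config_from_context_dim(context_dim):
    """Pick UNet attention-head settings from the CLIP context width.

    768  -> SD1.x checkpoints: a fixed number of heads.
    1024 -> SD2.x checkpoints: a fixed per-head channel width.
    """
    if context_dim == 768:
        return {"num_heads": 8}           # SD1.x
    return {"num_head_channels": 64}      # SD2.x
```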
--- comfy/sd.py | 21 +++++---------------- comfy/utils.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index 336fee4a6..04eaaa9fe 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -31,17 +31,6 @@ def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): if ids.dtype == torch.float32: sd['cond_stage_model.transformer.text_model.embeddings.position_ids'] = ids.round() - keys_to_replace = { - "cond_stage_model.model.positional_embedding": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight", - "cond_stage_model.model.token_embedding.weight": "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight", - "cond_stage_model.model.ln_final.weight": "cond_stage_model.transformer.text_model.final_layer_norm.weight", - "cond_stage_model.model.ln_final.bias": "cond_stage_model.transformer.text_model.final_layer_norm.bias", - } - - for x in keys_to_replace: - if x in sd: - sd[keys_to_replace[x]] = sd.pop(x) - sd = utils.transformers_convert(sd, "cond_stage_model.model", "cond_stage_model.transformer.text_model", 24) for x in load_state_dict_to: @@ -1073,13 +1062,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o "legacy": False } - if len(sd['model.diffusion_model.input_blocks.1.1.proj_in.weight'].shape) == 2: + if len(sd['model.diffusion_model.input_blocks.4.1.proj_in.weight'].shape) == 2: unet_config['use_linear_in_transformer'] = True unet_config["use_fp16"] = fp16 unet_config["model_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[0] unet_config["in_channels"] = sd['model.diffusion_model.input_blocks.0.0.weight'].shape[1] - unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] + unet_config["context_dim"] = sd['model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'].shape[1] sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} @@ -1097,10 +1086,10 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o else: sd_config["conditioning_key"] = "crossattn" - if unet_config["context_dim"] == 1024: - unet_config["num_head_channels"] = 64 #SD2.x - else: + if unet_config["context_dim"] == 768: unet_config["num_heads"] = 8 #SD1.x + else: + unet_config["num_head_channels"] = 64 #SD2.x unclip = 'model.diffusion_model.label_emb.0.0.weight' if unclip in sd_keys: diff --git a/comfy/utils.py b/comfy/utils.py index 291c62e42..585ebda51 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -24,6 +24,18 @@ def load_torch_file(ckpt, safe_load=False): return sd def transformers_convert(sd, prefix_from, prefix_to, number): + keys_to_replace = { + "{}.positional_embedding": "{}.embeddings.position_embedding.weight", + "{}.token_embedding.weight": "{}.embeddings.token_embedding.weight", + "{}.ln_final.weight": "{}.final_layer_norm.weight", + "{}.ln_final.bias": "{}.final_layer_norm.bias", + } + + for k in keys_to_replace: + x = k.format(prefix_from) + if x in sd: + sd[keys_to_replace[k].format(prefix_to)] = sd.pop(x) + resblock_to_replace = { "ln_1": "layer_norm1", "ln_2": "layer_norm2", From 5cf4079923c94ab3a507d89674bca3bf6f3dbc5b Mon Sep 17 00:00:00 2001 From: reaper47 Date: Wed, 7 Jun 2023 15:15:38 +0200 Subject: [PATCH 69/82] Give linux some love --- README.md | 51 
++++++++++++++++++++++++++------------------------- main.py | 20 +++++++++++--------- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index d998afe65..55cd25c72 100644 --- a/README.md +++ b/README.md @@ -38,28 +38,28 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git ## Shortcuts -| Keybind | Explanation | -| - | - | -| Ctrl + Enter | Queue up current graph for generation | -| Ctrl + Shift + Enter | Queue up current graph as first for generation | -| Ctrl + S | Save workflow | -| Ctrl + O | Load workflow | -| Ctrl + A | Select all nodes | -| Ctrl + M | Mute/unmute selected nodes | -| Delete/Backspace | Delete selected nodes | -| Ctrl + Delete/Backspace | Delete the current graph | -| Space | Move the canvas around when held and moving the cursor | -| Ctrl/Shift + Click | Add clicked node to selection | -| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) | -| Ctrl + C/Ctrl + Shift + V| Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) | -| Shift + Drag | Move multiple selected nodes at the same time | -| Ctrl + D | Load default graph | -| Q | Toggle visibility of the queue | -| H | Toggle visibility of history | -| R | Refresh graph | -| Double-Click LMB | Open node quick search palette | +| Keybind | Explanation | +|---------------------------|--------------------------------------------------------------------------------------------------------------------| +| Ctrl + Enter | Queue up current graph for generation | +| Ctrl + Shift + Enter | Queue up current graph as first for generation | +| Ctrl + S | Save workflow | +| Ctrl + O | Load workflow | +| Ctrl + A | Select all nodes | +| Ctrl + M | Mute/unmute selected nodes | +| Delete/Backspace | Delete selected nodes | +| Ctrl + Delete/Backspace | Delete the current graph | +| Space | Move the canvas around when held and moving the cursor | +| Ctrl/Shift + Click | Add clicked node to selection | +| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) | +| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) | +| Shift + Drag | Move multiple selected nodes at the same time | +| Ctrl + D | Load default graph | +| Q | Toggle visibility of the queue | +| H | Toggle visibility of history | +| R | Refresh graph | +| Double-Click LMB | Open node quick search palette | -Ctrl can also be replaced with Cmd instead for MacOS users +Ctrl can also be replaced with Cmd instead for macOS users # Installing @@ -77,7 +77,8 @@ See the [Config file](extra_model_paths.yaml.example) to set the search paths fo ## Colab Notebook -To run it on colab or paperspace you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb) +To run it on colab or paperspace you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: +Link to open with Google colab ## Manual Install (Windows, Linux) @@ -125,7 +126,7 @@ Mac/MPS: There is basic support in the code but until someone makes some install ### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies? -You don't. 
If you have another UI installed and working with it's own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it: +You don't. If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it: ```source path_to_other_sd_gui/venv/bin/activate``` @@ -135,7 +136,7 @@ With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"``` With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"``` -And then you can use that terminal to run Comfyui without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI. +And then you can use that terminal to run ComfyUI without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI. # Running @@ -190,7 +191,7 @@ The default installation includes a fast latent preview method that's low-resolu ## Support and dev channel -[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source). +Matrix space: #comfyui_space:matrix.org (it's like discord but open source). # QA diff --git a/main.py b/main.py index 15f75f892..8293c06fc 100644 --- a/main.py +++ b/main.py @@ -37,21 +37,25 @@ def prompt_worker(q, server): e.execute(item[2], item[1], item[3], item[4]) q.task_done(item_id, e.outputs_ui) + async def run(server, address='', port=8188, verbose=True, call_on_start=None): await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop()) + def hijack_progress(server): def hook(value, total, preview_image_bytes): - server.send_sync("progress", { "value": value, "max": total}, server.client_id) + server.send_sync("progress", {"value": value, "max": total}, server.client_id) if preview_image_bytes is not None: server.send_sync(BinaryEventTypes.PREVIEW_IMAGE, preview_image_bytes, server.client_id) comfy.utils.set_progress_bar_global_hook(hook) + def cleanup_temp(): temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") if os.path.exists(temp_dir): shutil.rmtree(temp_dir, ignore_errors=True) + def load_extra_path_config(yaml_path): with open(yaml_path, 'r') as stream: config = yaml.safe_load(stream) @@ -72,6 +76,7 @@ def load_extra_path_config(yaml_path): print("Adding extra search path", x, full_path) folder_paths.add_model_folder_path(x, full_path) + if __name__ == "__main__": cleanup_temp() @@ -92,7 +97,7 @@ if __name__ == "__main__": server.add_routes() hijack_progress(server) - threading.Thread(target=prompt_worker, daemon=True, args=(q,server,)).start() + threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start() if args.output_directory: output_dir = os.path.abspath(args.output_directory) @@ -106,15 +111,12 @@ if __name__ == "__main__": if args.auto_launch: def startup_server(address, port): import webbrowser - webbrowser.open("http://{}:{}".format(address, port)) + webbrowser.open(f"http://{address}:{port}") call_on_start = startup_server - if os.name == "nt": - try: - loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) - except KeyboardInterrupt: - pass - else: + try: loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)) + except KeyboardInterrupt: + print("\nStopped server") 
cleanup_temp() From 70e02b443f6b803ba4d08aa9cc0f0286f9e5dd2b Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Wed, 7 Jun 2023 22:56:08 +0900 Subject: [PATCH 70/82] robust patch on pasteFromClipspace --- web/scripts/app.js | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 27c67fb49..657ea0246 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -125,10 +125,14 @@ export class ComfyApp { if(ComfyApp.clipspace.imgs && node.imgs) { if(node.images && ComfyApp.clipspace.images) { if(ComfyApp.clipspace['img_paste_mode'] == 'selected') { - app.nodeOutputs[node.id + ""].images = node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; + node.images = [ComfyApp.clipspace.images[ComfyApp.clipspace['selectedIndex']]]; } - else - app.nodeOutputs[node.id + ""].images = node.images = ComfyApp.clipspace.images; + else { + node.images = ComfyApp.clipspace.images; + } + + if(app.nodeOutputs[node.id + ""]) + app.nodeOutputs[node.id + ""].images = node.images; } if(ComfyApp.clipspace.imgs) { From 28677342c1d2f2eb86aadd2e8fceac9c2f6196bc Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Thu, 8 Jun 2023 00:06:56 +0900 Subject: [PATCH 71/82] robust paste for image --- web/scripts/app.js | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 657ea0246..385a54579 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -165,7 +165,16 @@ export class ComfyApp { if(ComfyApp.clipspace.widgets) { ComfyApp.clipspace.widgets.forEach(({ type, name, value }) => { const prop = Object.values(node.widgets).find(obj => obj.type === type && obj.name === name); - if (prop && prop.type != 'button') { + if (prop && prop.type != 'image') { + if(typeof prop.value == "string" && value.filename) { + prop.value = (value.subfolder?value.subfolder+'/':'') + value.filename + (value.type?` [${value.type}]`:''); + } + else { + prop.value = value; + prop.callback(value); + } + } + else if (prop && prop.type != 'button') { prop.value = value; prop.callback(value); } From 29c50954eace60fe8422c63ff4bf57c389a1ba76 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 8 Jun 2023 02:00:44 -0400 Subject: [PATCH 72/82] Add some quick instructions how to use directml. --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 55cd25c72..78f34a9bb 100644 --- a/README.md +++ b/README.md @@ -77,8 +77,7 @@ See the [Config file](extra_model_paths.yaml.example) to set the search paths fo ## Colab Notebook -To run it on colab or paperspace you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: -Link to open with Google colab +To run it on colab or paperspace you can use my [Colab Notebook](notebooks/comfyui_colab.ipynb) here: [Link to open with google colab](https://colab.research.google.com/github/comfyanonymous/ComfyUI/blob/master/notebooks/comfyui_colab.ipynb) ## Manual Install (Windows, Linux) @@ -124,6 +123,9 @@ After this you should have everything installed and can proceed to running Comfy Mac/MPS: There is basic support in the code but until someone makes some install instruction you are on your own. +Directml: ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` + + ### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies? You don't. 
If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it: @@ -191,7 +193,7 @@ The default installation includes a fast latent preview method that's low-resolu ## Support and dev channel -Matrix space: #comfyui_space:matrix.org (it's like discord but open source). +[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source). # QA From eed4f62cc5db387c6389893b64ddc18aadf0a04d Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 8 Jun 2023 12:08:00 -0500 Subject: [PATCH 73/82] Add comment support to dynamic prompts nodes --- web/extensions/core/dynamicPrompts.js | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/web/extensions/core/dynamicPrompts.js b/web/extensions/core/dynamicPrompts.js index 7dae07f4d..599a9e685 100644 --- a/web/extensions/core/dynamicPrompts.js +++ b/web/extensions/core/dynamicPrompts.js @@ -3,6 +3,13 @@ import { app } from "../../scripts/app.js"; // Allows for simple dynamic prompt replacement // Inputs in the format {a|b} will have a random value of a or b chosen when the prompt is queued. +/* + * Strips C-style line and block comments from a string + */ +function stripComments(str) { + return str.replace(/\/\*[\s\S]*?\*\/|\/\/.*/g,''); +} + app.registerExtension({ name: "Comfy.DynamicPrompts", nodeCreated(node) { @@ -15,7 +22,7 @@ app.registerExtension({ for (const widget of widgets) { // Override the serialization of the value to resolve dynamic prompts for all widgets supporting it in this node widget.serializeValue = (workflowNode, widgetIndex) => { - let prompt = widget.value; + let prompt = stripComments(widget.value); while (prompt.replace("\\{", "").includes("{") && prompt.replace("\\}", "").includes("}")) { const startIndex = prompt.replace("\\{", "00").indexOf("{"); const endIndex = prompt.replace("\\}", "00").indexOf("}"); From 65922419e2c12be6333a2d2889106eb6f250beeb Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 8 Jun 2023 12:12:07 -0500 Subject: [PATCH 74/82] Add comment note in README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 78f34a9bb..d9083b7e1 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,8 @@ You can use () to change emphasis of a word or phrase like: (good code:1.2) or ( You can use {day|night}, for wildcard/dynamic prompts. With this syntax "{wild|card|test}" will be randomly replaced by either "wild", "card" or "test" by the frontend every time you queue the prompt. To use {} characters in your actual prompt escape them like: \\{ or \\}. +Dynamic prompts also support C-style comments, like `// comment` or `/* comment */`. + To use a textual inversion concepts/embeddings in a text prompt put them in the models/embeddings directory and use them in the CLIPTextEncode node like this (you can omit the .pt extension): ```embedding:embedding_filename.pt``` From 23cf8ca7c52ef2abb86c820ee751bbafe4d3e6ed Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 8 Jun 2023 23:48:14 -0400 Subject: [PATCH 75/82] Fix bug when embedding gets ignored because of mismatched size. 
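Patches 73 and 74 above add C-style comment support to the dynamic-prompt widgets and document it in the README. The stripping happens in the JavaScript frontend; a rough Python equivalent of the documented behaviour (comment removal plus `{a|b}` wildcard resolution, ignoring the `\{`/`\}` escape handling for brevity):

```python
import random
import re

def strip_comments(text: str) -> str:
    # Same pattern as the frontend: remove /* block */ and // line comments.
    return re.sub(r"/\*[\s\S]*?\*/|//.*", "", text)

def resolve_dynamic_prompt(prompt: str) -> str:
    # Replace each {a|b|c} group with one randomly chosen option, innermost groups first.
    prompt = strip_comments(prompt)
    pattern = re.compile(r"\{([^{}]*)\}")
    while pattern.search(prompt):
        prompt = pattern.sub(lambda m: random.choice(m.group(1).split("|")), prompt, count=1)
    return prompt

# Example: resolve_dynamic_prompt("{wild|card|test} landscape // keep it simple")
# returns e.g. "card landscape "
```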
--- comfy/sd1_clip.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index b1a392736..91fb4ff27 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -82,6 +82,8 @@ class SD1ClipModel(torch.nn.Module, ClipTokenWeightEncoder): next_new_token += 1 else: print("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored", y.shape[0], current_embeds.weight.shape[1]) + while len(tokens_temp) < len(x): + tokens_temp += [self.empty_tokens[0][-1]] out_tokens += [tokens_temp] if len(embedding_weights) > 0: From 8e14c46a381f2cd5429c0b82d5766816d9a58282 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Fri, 9 Jun 2023 15:21:30 +0900 Subject: [PATCH 76/82] allows connect primitive node to reroute if primitive node has type (#751) Co-authored-by: Lt.Dr.Data --- web/extensions/core/widgetInputs.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 4fe0a6013..8955fca87 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -240,6 +240,7 @@ app.registerExtension({ // No widget, we cant connect if (!input.widget) { + if (this.outputs[0]?.type != '*' && target_node.type == "Reroute") return true; if (!(input.type in ComfyWidgets)) return false; } From 4b0b516544d5f4896acaa574aae18cba798a2be8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Jun 2023 02:48:42 -0400 Subject: [PATCH 77/82] Add code to handle primitive nodes connected to reroute nodes. Revert last commit because I noticed it broke a few things. --- web/extensions/core/widgetInputs.js | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/web/extensions/core/widgetInputs.js b/web/extensions/core/widgetInputs.js index 8955fca87..c356655b0 100644 --- a/web/extensions/core/widgetInputs.js +++ b/web/extensions/core/widgetInputs.js @@ -200,8 +200,23 @@ app.registerExtension({ applyToGraph() { if (!this.outputs[0].links?.length) return; + function get_links(node) { + let links = []; + for (const l of node.outputs[0].links) { + const linkInfo = app.graph.links[l]; + const n = node.graph.getNodeById(linkInfo.target_id); + if (n.type == "Reroute") { + links = links.concat(get_links(n)); + } else { + links.push(l); + } + } + return links; + } + + let links = get_links(this); // For each output link copy our value over the original widget value - for (const l of this.outputs[0].links) { + for (const l of links) { const linkInfo = app.graph.links[l]; const node = this.graph.getNodeById(linkInfo.target_id); const input = node.inputs[linkInfo.target_slot]; @@ -240,7 +255,6 @@ app.registerExtension({ // No widget, we cant connect if (!input.widget) { - if (this.outputs[0]?.type != '*' && target_node.type == "Reroute") return true; if (!(input.type in ComfyWidgets)) return false; } From bfebe2d6c36e2ab7cf9ce9892abe80fc3057c46f Mon Sep 17 00:00:00 2001 From: reaper47 Date: Fri, 9 Jun 2023 13:29:15 +0200 Subject: [PATCH 78/82] Improve ContextMenuFilter extension --- web/extensions/core/contextMenuFilter.js | 226 ++++++++++++----------- web/style.css | 15 +- 2 files changed, 128 insertions(+), 113 deletions(-) diff --git a/web/extensions/core/contextMenuFilter.js b/web/extensions/core/contextMenuFilter.js index 51e66f924..662d87e74 100644 --- a/web/extensions/core/contextMenuFilter.js +++ b/web/extensions/core/contextMenuFilter.js @@ -1,132 +1,138 @@ -import { app } from "/scripts/app.js"; 
+import {app} from "/scripts/app.js"; // Adds filtering to combo context menus -const id = "Comfy.ContextMenuFilter"; -app.registerExtension({ - name: id, +const ext = { + name: "Comfy.ContextMenuFilter", init() { const ctxMenu = LiteGraph.ContextMenu; + LiteGraph.ContextMenu = function (values, options) { const ctx = ctxMenu.call(this, values, options); // If we are a dark menu (only used for combo boxes) then add a filter input if (options?.className === "dark" && values?.length > 10) { const filter = document.createElement("input"); - Object.assign(filter.style, { - width: "calc(100% - 10px)", - border: "0", - boxSizing: "border-box", - background: "#333", - border: "1px solid #999", - margin: "0 0 5px 5px", - color: "#fff", - }); + filter.classList.add("comfy-context-menu-filter"); filter.placeholder = "Filter list"; this.root.prepend(filter); - let selectedIndex = 0; - let items = this.root.querySelectorAll(".litemenu-entry"); - let itemCount = items.length; - let selectedItem; + const items = Array.from(this.root.querySelectorAll(".litemenu-entry")); + let displayedItems = [...items]; + let itemCount = displayedItems.length; - // Apply highlighting to the selected item - function updateSelected() { - if (selectedItem) { - selectedItem.style.setProperty("background-color", ""); - selectedItem.style.setProperty("color", ""); - } - selectedItem = items[selectedIndex]; - if (selectedItem) { - selectedItem.style.setProperty("background-color", "#ccc", "important"); - selectedItem.style.setProperty("color", "#000", "important"); - } - } + // We must request an animation frame for the current node of the active canvas to update. + requestAnimationFrame(() => { + const currentNode = LGraphCanvas.active_canvas.current_node; + const clickedComboValue = currentNode.widgets + .filter(w => w.type === "combo" && w.options.values.length === values.length) + .find(w => w.options.values.every((v, i) => v === values[i])) + .value; - const positionList = () => { - const rect = this.root.getBoundingClientRect(); - - // If the top is off screen then shift the element with scaling applied - if (rect.top < 0) { - const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight; - const shift = (this.root.clientHeight * scale) / 2; - this.root.style.top = -shift + "px"; - } - } - - updateSelected(); - - // Arrow up/down to select items - filter.addEventListener("keydown", (e) => { - if (e.key === "ArrowUp") { - if (selectedIndex === 0) { - selectedIndex = itemCount - 1; - } else { - selectedIndex--; - } - updateSelected(); - e.preventDefault(); - } else if (e.key === "ArrowDown") { - if (selectedIndex === itemCount - 1) { - selectedIndex = 0; - } else { - selectedIndex++; - } - updateSelected(); - e.preventDefault(); - } else if ((selectedItem && e.key === "Enter") || e.keyCode === 13 || e.keyCode === 10) { - selectedItem.click(); - } else if(e.key === "Escape") { - this.close(); - } - }); - - filter.addEventListener("input", () => { - // Hide all items that dont match our filter - const term = filter.value.toLocaleLowerCase(); - items = this.root.querySelectorAll(".litemenu-entry"); - // When filtering recompute which items are visible for arrow up/down - // Try and maintain selection - let visibleItems = []; - for (const item of items) { - const visible = !term || item.textContent.toLocaleLowerCase().includes(term); - if (visible) { - item.style.display = "block"; - if (item === selectedItem) { - selectedIndex = visibleItems.length; - } - visibleItems.push(item); - } else { - 
item.style.display = "none"; - if (item === selectedItem) { - selectedIndex = 0; - } - } - } - items = visibleItems; + let selectedIndex = values.findIndex(v => v === clickedComboValue); + let selectedItem = displayedItems?.[selectedIndex]; updateSelected(); - // If we have an event then we can try and position the list under the source - if (options.event) { - let top = options.event.clientY - 10; - - const bodyRect = document.body.getBoundingClientRect(); - const rootRect = this.root.getBoundingClientRect(); - if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) { - top = Math.max(0, bodyRect.height - rootRect.height - 10); - } - - this.root.style.top = top + "px"; - positionList(); + // Apply highlighting to the selected item + function updateSelected() { + selectedItem?.style.setProperty("background-color", ""); + selectedItem?.style.setProperty("color", ""); + selectedItem = displayedItems[selectedIndex]; + selectedItem?.style.setProperty("background-color", "#ccc", "important"); + selectedItem?.style.setProperty("color", "#000", "important"); } - }); - requestAnimationFrame(() => { - // Focus the filter box when opening - filter.focus(); + const positionList = () => { + const rect = this.root.getBoundingClientRect(); - positionList(); - }); + // If the top is off-screen then shift the element with scaling applied + if (rect.top < 0) { + const scale = 1 - this.root.getBoundingClientRect().height / this.root.clientHeight; + const shift = (this.root.clientHeight * scale) / 2; + this.root.style.top = -shift + "px"; + } + } + + // Arrow up/down to select items + filter.addEventListener("keydown", (event) => { + switch (event.key) { + case "ArrowUp": + event.preventDefault(); + if (selectedIndex === 0) { + selectedIndex = itemCount - 1; + } else { + selectedIndex--; + } + updateSelected(); + break; + case "ArrowRight": + event.preventDefault(); + selectedIndex = itemCount - 1; + updateSelected(); + break; + case "ArrowDown": + event.preventDefault(); + if (selectedIndex === itemCount - 1) { + selectedIndex = 0; + } else { + selectedIndex++; + } + updateSelected(); + break; + case "ArrowLeft": + event.preventDefault(); + selectedIndex = 0; + updateSelected(); + break; + case "Enter": + selectedItem?.click(); + break; + case "Escape": + this.close(); + break; + } + }); + + filter.addEventListener("input", () => { + // Hide all items that don't match our filter + const term = filter.value.toLocaleLowerCase(); + // When filtering, recompute which items are visible for arrow up/down and maintain selection. + displayedItems = items.filter(item => { + const isVisible = !term || item.textContent.toLocaleLowerCase().includes(term); + item.style.display = isVisible ? 
"block" : "none"; + return isVisible; + }); + + selectedIndex = 0; + if (displayedItems.includes(selectedItem)) { + selectedIndex = displayedItems.findIndex(d => d === selectedItem); + } + itemCount = displayedItems.length; + + updateSelected(); + + // If we have an event then we can try and position the list under the source + if (options.event) { + let top = options.event.clientY - 10; + + const bodyRect = document.body.getBoundingClientRect(); + const rootRect = this.root.getBoundingClientRect(); + if (bodyRect.height && top > bodyRect.height - rootRect.height - 10) { + top = Math.max(0, bodyRect.height - rootRect.height - 10); + } + + this.root.style.top = top + "px"; + positionList(); + } + }); + + requestAnimationFrame(() => { + // Focus the filter box when opening + filter.focus(); + + positionList(); + }); + }) } return ctx; @@ -134,4 +140,6 @@ app.registerExtension({ LiteGraph.ContextMenu.prototype = ctxMenu.prototype; }, -}); +} + +app.registerExtension(ext); diff --git a/web/style.css b/web/style.css index 47571a16e..5fea5bba8 100644 --- a/web/style.css +++ b/web/style.css @@ -50,7 +50,7 @@ body { padding: 30px 30px 10px 30px; background-color: var(--comfy-menu-bg); /* Modal background */ color: var(--error-text); - box-shadow: 0px 0px 20px #888888; + box-shadow: 0 0 20px #888888; border-radius: 10px; top: 50%; left: 50%; @@ -84,7 +84,7 @@ body { font-size: 15px; position: absolute; top: 50%; - right: 0%; + right: 0; text-align: center; z-index: 100; width: 170px; @@ -252,7 +252,7 @@ button.comfy-queue-btn { bottom: 0 !important; left: auto !important; right: 0 !important; - border-radius: 0px; + border-radius: 0; } .comfy-menu span.drag-handle { visibility:hidden @@ -291,7 +291,7 @@ button.comfy-queue-btn { .litegraph .dialog { z-index: 1; - font-family: Arial; + font-family: Arial, sans-serif; } .litegraph .litemenu-entry.has_submenu { @@ -330,6 +330,13 @@ button.comfy-queue-btn { color: var(--input-text) !important; } +.comfy-context-menu-filter { + box-sizing: border-box; + border: 1px solid #999; + margin: 0 0 5px 5px; + width: calc(100% - 10px); +} + /* Search box */ .litegraph.litesearchbox { From de142eaad5818cf4e448d8edc479c89e9b59aff0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 9 Jun 2023 12:24:24 -0400 Subject: [PATCH 79/82] Simpler base model code. 
--- comfy/diffusers_load.py | 28 ++-------------- comfy/model_base.py | 66 +++++++++++++++++++++++++++++++++++++ comfy/samplers.py | 71 ++++++++++++++++++++++------------------ comfy/sd.py | 72 +++++++++++++++++++++++++++++++---------- 4 files changed, 163 insertions(+), 74 deletions(-) create mode 100644 comfy/model_base.py diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py index 43877fb83..f494f1d30 100644 --- a/comfy/diffusers_load.py +++ b/comfy/diffusers_load.py @@ -4,7 +4,7 @@ import yaml import folder_paths from comfy.ldm.util import instantiate_from_config -from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE +from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE, load_checkpoint import os.path as osp import re import torch @@ -84,28 +84,4 @@ def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, emb # Put together new checkpoint sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} - clip = None - vae = None - - class WeightsLoader(torch.nn.Module): - pass - - w = WeightsLoader() - load_state_dict_to = [] - if output_vae: - vae = VAE(scale_factor=scale_factor, config=vae_config) - w.first_stage_model = vae.first_stage_model - load_state_dict_to = [w] - - if output_clip: - clip = CLIP(config=clip_config, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - load_state_dict_to = [w] - - model = instantiate_from_config(config["model"]) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) - - if fp16: - model = model.half() - - return ModelPatcher(model), clip, vae + return load_checkpoint(embedding_directory=embedding_directory, state_dict=sd, config=config) diff --git a/comfy/model_base.py b/comfy/model_base.py new file mode 100644 index 000000000..7370c19fd --- /dev/null +++ b/comfy/model_base.py @@ -0,0 +1,66 @@ +import torch +from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel +from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation +from comfy.ldm.modules.diffusionmodules.util import make_beta_schedule +import numpy as np + +class BaseModel(torch.nn.Module): + def __init__(self, unet_config, v_prediction=False): + super().__init__() + + self.register_schedule(given_betas=None, beta_schedule="linear", timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3) + self.diffusion_model = UNetModel(**unet_config) + self.v_prediction = v_prediction + if self.v_prediction: + self.parameterization = "v" + else: + self.parameterization = "eps" + if "adm_in_channels" in unet_config: + self.adm_channels = unet_config["adm_in_channels"] + else: + self.adm_channels = 0 + print("v_prediction", v_prediction) + print("adm", self.adm_channels) + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if given_betas is not None: + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + + self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32)) + self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32)) + self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) + + def apply_model(self, x, t, c_concat=None, c_crossattn=None, c_adm=None, control=None, transformer_options={}): + if c_concat is not None: + xc = torch.cat([x] + c_concat, dim=1) + else: + xc = x + context = torch.cat(c_crossattn, 1) + return self.diffusion_model(xc, t, context=context, y=c_adm, control=control, transformer_options=transformer_options) + + def get_dtype(self): + return self.diffusion_model.dtype + + def is_adm(self): + return self.adm_channels > 0 + +class SD21UNCLIP(BaseModel): + def __init__(self, unet_config, noise_aug_config, v_prediction=True): + super().__init__(unet_config, v_prediction) + self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config) + +class SDInpaint(BaseModel): + def __init__(self, unet_config, v_prediction=False): + super().__init__(unet_config, v_prediction) + self.concat_keys = ("mask", "masked_image") diff --git a/comfy/samplers.py b/comfy/samplers.py index 1fb928f8d..a33d150d0 100644 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -248,7 +248,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con c['transformer_options'] = transformer_options - output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks) + output = model_function(input_x, timestep_, **c).chunk(batch_chunks) del input_x model_management.throw_exception_if_processing_interrupted() @@ -460,36 +460,42 @@ def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): uncond[temp[1]] = [o[0], n] -def encode_adm(noise_augmentor, conds, batch_size, device): +def encode_adm(conds, batch_size, device, noise_augmentor=None): for t in range(len(conds)): x = conds[t] - if 'adm' in x[1]: - adm_inputs = [] - weights = [] - noise_aug = [] - adm_in = x[1]["adm"] - for adm_c in adm_in: - adm_cond = adm_c[0].image_embeds - weight = adm_c[1] - noise_augment = adm_c[2] - noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) - adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight - weights.append(weight) - noise_aug.append(noise_augment) - adm_inputs.append(adm_out) + adm_out = None + if noise_augmentor is not None: + if 'adm' in x[1]: + adm_inputs = [] + weights = [] + noise_aug = [] + adm_in = x[1]["adm"] + for adm_c in adm_in: + adm_cond = adm_c[0].image_embeds + weight = adm_c[1] + noise_augment = adm_c[2] + noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight + weights.append(weight) + noise_aug.append(noise_augment) + adm_inputs.append(adm_out) - if len(noise_aug) > 1: - adm_out = torch.stack(adm_inputs).sum(0) - #TODO: add a way to control this - noise_augment = 0.05 - noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) - c_adm, noise_level_emb = 
noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) - adm_out = torch.cat((c_adm, noise_level_emb), 1) + if len(noise_aug) > 1: + adm_out = torch.stack(adm_inputs).sum(0) + #TODO: add a way to control this + noise_augment = 0.05 + noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment) + c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device)) + adm_out = torch.cat((c_adm, noise_level_emb), 1) + else: + adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) else: - adm_out = torch.zeros((1, noise_augmentor.time_embed.dim * 2), device=device) - x[1] = x[1].copy() - x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size) + if 'adm' in x[1]: + adm_out = x[1]["adm"].to(device) + if adm_out is not None: + x[1] = x[1].copy() + x[1]["adm_encoded"] = torch.cat([adm_out] * batch_size) return conds @@ -591,14 +597,17 @@ class KSampler: apply_empty_x_to_equal_area(positive, negative, 'control', lambda cond_cnets, x: cond_cnets[x]) apply_empty_x_to_equal_area(positive, negative, 'gligen', lambda cond_cnets, x: cond_cnets[x]) - if self.model.model.diffusion_model.dtype == torch.float16: + if self.model.get_dtype() == torch.float16: precision_scope = torch.autocast else: precision_scope = contextlib.nullcontext - if hasattr(self.model, 'noise_augmentor'): #unclip - positive = encode_adm(self.model.noise_augmentor, positive, noise.shape[0], self.device) - negative = encode_adm(self.model.noise_augmentor, negative, noise.shape[0], self.device) + if self.model.is_adm(): + noise_augmentor = None + if hasattr(self.model, 'noise_augmentor'): #unclip + noise_augmentor = self.model.noise_augmentor + positive = encode_adm(positive, noise.shape[0], self.device, noise_augmentor) + negative = encode_adm(negative, noise.shape[0], self.device, noise_augmentor) extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options} diff --git a/comfy/sd.py b/comfy/sd.py index 04eaaa9fe..3747f53b8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -15,8 +15,15 @@ from . import utils from . import clip_vision from . import gligen from . import diffusers_convert +from . 
import model_base def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): + replace_prefix = {"model.diffusion_model.": "diffusion_model."} + for rp in replace_prefix: + replace = list(map(lambda a: (a, "{}{}".format(replace_prefix[rp], a[len(rp):])), filter(lambda a: a.startswith(rp), sd.keys()))) + for x in replace: + sd[x[1]] = sd.pop(x[0]) + m, u = model.load_state_dict(sd, strict=False) k = list(sd.keys()) @@ -182,7 +189,7 @@ def model_lora_keys(model, key_map={}): counter = 0 for b in range(12): - tk = "model.diffusion_model.input_blocks.{}.1".format(b) + tk = "diffusion_model.input_blocks.{}.1".format(b) up_counter = 0 for c in LORA_UNET_MAP_ATTENTIONS: k = "{}.{}.weight".format(tk, c) @@ -193,13 +200,13 @@ def model_lora_keys(model, key_map={}): if up_counter >= 4: counter += 1 for c in LORA_UNET_MAP_ATTENTIONS: - k = "model.diffusion_model.middle_block.1.{}.weight".format(c) + k = "diffusion_model.middle_block.1.{}.weight".format(c) if k in sdk: lora_key = "lora_unet_mid_block_attentions_0_{}".format(LORA_UNET_MAP_ATTENTIONS[c]) key_map[lora_key] = k counter = 3 for b in range(12): - tk = "model.diffusion_model.output_blocks.{}.1".format(b) + tk = "diffusion_model.output_blocks.{}.1".format(b) up_counter = 0 for c in LORA_UNET_MAP_ATTENTIONS: k = "{}.{}.weight".format(tk, c) @@ -223,7 +230,7 @@ def model_lora_keys(model, key_map={}): ds_counter = 0 counter = 0 for b in range(12): - tk = "model.diffusion_model.input_blocks.{}.0".format(b) + tk = "diffusion_model.input_blocks.{}.0".format(b) key_in = False for c in LORA_UNET_MAP_RESNET: k = "{}.{}.weight".format(tk, c) @@ -242,7 +249,7 @@ def model_lora_keys(model, key_map={}): counter = 0 for b in range(3): - tk = "model.diffusion_model.middle_block.{}".format(b) + tk = "diffusion_model.middle_block.{}".format(b) key_in = False for c in LORA_UNET_MAP_RESNET: k = "{}.{}.weight".format(tk, c) @@ -256,7 +263,7 @@ def model_lora_keys(model, key_map={}): counter = 0 us_counter = 0 for b in range(12): - tk = "model.diffusion_model.output_blocks.{}.0".format(b) + tk = "diffusion_model.output_blocks.{}.0".format(b) key_in = False for c in LORA_UNET_MAP_RESNET: k = "{}.{}.weight".format(tk, c) @@ -332,7 +339,7 @@ class ModelPatcher: patch_list[i] = patch_list[i].to(device) def model_dtype(self): - return self.model.diffusion_model.dtype + return self.model.get_dtype() def add_patches(self, patches, strength=1.0): p = {} @@ -764,7 +771,7 @@ def load_controlnet(ckpt_path, model=None): for x in controlnet_data: c_m = "control_model." 
if x.startswith(c_m): - sd_key = "model.diffusion_model.{}".format(x[len(c_m):]) + sd_key = "diffusion_model.{}".format(x[len(c_m):]) if sd_key in model_sd: cd = controlnet_data[x] cd += model_sd[sd_key].type(cd.dtype).to(cd.device) @@ -931,9 +938,10 @@ def load_gligen(ckpt_path): model = model.half() return model -def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=None): - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) +def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_clip=True, embedding_directory=None, state_dict=None, config=None): + if config is None: + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) model_config_params = config['model']['params'] clip_config = model_config_params['cond_stage_config'] scale_factor = model_config_params['scale_factor'] @@ -942,8 +950,19 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e fp16 = False if "unet_config" in model_config_params: if "params" in model_config_params["unet_config"]: - if "use_fp16" in model_config_params["unet_config"]["params"]: - fp16 = model_config_params["unet_config"]["params"]["use_fp16"] + unet_config = model_config_params["unet_config"]["params"] + if "use_fp16" in unet_config: + fp16 = unet_config["use_fp16"] + + noise_aug_config = None + if "noise_aug_config" in model_config_params: + noise_aug_config = model_config_params["noise_aug_config"] + + v_prediction = False + + if "parameterization" in model_config_params: + if model_config_params["parameterization"] == "v": + v_prediction = True clip = None vae = None @@ -963,9 +982,16 @@ def load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, e w.cond_stage_model = clip.cond_stage_model load_state_dict_to = [w] - model = instantiate_from_config(config["model"]) - sd = utils.load_torch_file(ckpt_path) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) + if config['model']["target"].endswith("LatentInpaintDiffusion"): + model = model_base.SDInpaint(unet_config, v_prediction=v_prediction) + elif config['model']["target"].endswith("ImageEmbeddingConditionedLatentDiffusion"): + model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction) + else: + model = model_base.BaseModel(unet_config, v_prediction=v_prediction) + + if state_dict is None: + state_dict = utils.load_torch_file(ckpt_path) + model = load_model_weights(model, state_dict, verbose=False, load_state_dict_to=load_state_dict_to) if fp16: model = model.half() @@ -1073,16 +1099,20 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o sd_config["unet_config"] = {"target": "comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel", "params": unet_config} model_config = {"target": "comfy.ldm.models.diffusion.ddpm.LatentDiffusion", "params": sd_config} + unclip_model = False + inpaint_model = False if noise_aug_config is not None: #SD2.x unclip model sd_config["noise_aug_config"] = noise_aug_config sd_config["image_size"] = 96 sd_config["embedding_dropout"] = 0.25 sd_config["conditioning_key"] = 'crossattn-adm' + unclip_model = True model_config["target"] = "comfy.ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion" elif unet_config["in_channels"] > 4: #inpainting model sd_config["conditioning_key"] = "hybrid" sd_config["finetune_keys"] = None model_config["target"] = 
"comfy.ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + inpaint_model = True else: sd_config["conditioning_key"] = "crossattn" @@ -1096,13 +1126,21 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o unet_config["num_classes"] = "sequential" unet_config["adm_in_channels"] = sd[unclip].shape[1] + v_prediction = False if unet_config["context_dim"] == 1024 and unet_config["in_channels"] == 4: #only SD2.x non inpainting models are v prediction k = "model.diffusion_model.output_blocks.11.1.transformer_blocks.0.norm1.bias" out = sd[k] if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out. + v_prediction = True sd_config["parameterization"] = 'v' - model = instantiate_from_config(model_config) + if inpaint_model: + model = model_base.SDInpaint(unet_config, v_prediction=v_prediction) + elif unclip_model: + model = model_base.SD21UNCLIP(unet_config, noise_aug_config["params"], v_prediction=v_prediction) + else: + model = model_base.BaseModel(unet_config, v_prediction=v_prediction) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) if fp16: From 2bcdd6c7d401d2e036325c6f3d16aa7eff7ccf14 Mon Sep 17 00:00:00 2001 From: Jorge Campo <62282406+jorge-campo@users.noreply.github.com> Date: Fri, 9 Jun 2023 22:25:33 +0200 Subject: [PATCH 80/82] Add install instructions for Apple silicon --- README.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d9083b7e1..001641496 100644 --- a/README.md +++ b/README.md @@ -121,9 +121,16 @@ After this you should have everything installed and can proceed to running Comfy [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476) -Mac/MPS: There is basic support in the code but until someone makes some install instruction you are on your own. +You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version. -Directml: ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` +According to the [DirectML page](https://github.com/microsoft/DirectML#hardware-requirements), `pytorch-directml` package is not avilable for Apple silicon computers. However, you can still run ComfyUI without `pytorch-directml`. + +1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide. +1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux. +1. Install the ComfyUI [dependencies](#dependencies). If you have another UI to work with Stable Diffusion (such as Automatic1111), you can use the the packages for this installation. See [the instruction below](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies). +1. Launch ComfyUI by running `python main.py`. + +> **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux). ### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies? From ba23753670913a6f38a22f84cfb631e44f549a78 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 10 Jun 2023 03:23:01 -0400 Subject: [PATCH 81/82] DirectML is for Windows. 
--- README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 001641496..e33060307 100644 --- a/README.md +++ b/README.md @@ -123,8 +123,6 @@ After this you should have everything installed and can proceed to running Comfy You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version. -According to the [DirectML page](https://github.com/microsoft/DirectML#hardware-requirements), `pytorch-directml` package is not avilable for Apple silicon computers. However, you can still run ComfyUI without `pytorch-directml`. - 1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide. 1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux. 1. Install the ComfyUI [dependencies](#dependencies). If you have another UI to work with Stable Diffusion (such as Automatic1111), you can use the the packages for this installation. See [the instruction below](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies). @@ -132,6 +130,7 @@ According to the [DirectML page](https://github.com/microsoft/DirectML#hardware- > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux). +DirectML (AMD Cards on Windows): ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` ### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies? From 656f62569d4e3858d7f99ec7a44f23e885353285 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sat, 10 Jun 2023 04:19:33 -0400 Subject: [PATCH 82/82] Make the sections in the others install section more clearly separate. --- README.md | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index e33060307..1de9d4c3b 100644 --- a/README.md +++ b/README.md @@ -119,18 +119,22 @@ After this you should have everything installed and can proceed to running Comfy ### Others: -[Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476) +#### [Intel Arc](https://github.com/comfyanonymous/ComfyUI/discussions/476) + +#### Apple Mac silicon You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version. 1. Install pytorch. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide. 1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux. -1. Install the ComfyUI [dependencies](#dependencies). If you have another UI to work with Stable Diffusion (such as Automatic1111), you can use the the packages for this installation. See [the instruction below](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies). +1. Install the ComfyUI [dependencies](#dependencies). If you have another Stable Diffusion UI [you might be able to reuse the dependencies](#i-already-have-another-ui-for-stable-diffusion-installed-do-i-really-have-to-install-all-of-these-dependencies). 1. Launch ComfyUI by running `python main.py`. > **Note**: Remember to add your models, VAE, LoRAs etc. to the corresponding Comfy folders, as discussed in [ComfyUI manual installation](#manual-install-windows-linux). 
-DirectML (AMD Cards on Windows): ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` +#### DirectML (AMD Cards on Windows) + +```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml``` ### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies?
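The README changes above split the "Others" install section into separate Intel Arc, Apple silicon, and DirectML subsections. As a rough illustration of how those platforms map to a PyTorch device at runtime, here is a hedged sketch, not ComfyUI's actual startup code; the `torch_directml.device()` call is assumed from the optional `torch-directml` package mentioned in the README and may differ between versions.

```python
import torch

def pick_device():
    """Rough sketch: map the platforms in the install section to a torch device."""
    if torch.cuda.is_available():                  # NVIDIA (or ROCm builds of PyTorch)
        return torch.device("cuda")
    mps = getattr(torch.backends, "mps", None)     # guard for older PyTorch without MPS
    if mps is not None and torch.backends.mps.is_available():
        return torch.device("mps")                 # Apple silicon path from the macOS steps above
    try:
        import torch_directml                      # installed via `pip install torch-directml` (Windows/AMD)
        return torch_directml.device()             # assumed helper from that package
    except ImportError:
        return torch.device("cpu")                 # fallback when no accelerator is available

if __name__ == "__main__":
    print(pick_device())
```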