From ffc56c53c9cccfcc21c92fe14cb095bb32ea2744 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 13:22:38 -0400 Subject: [PATCH 01/48] Add a node_errors to the /prompt error json response. "node_errors" contains a dict keyed by node ids. The contents are a message and a list of dependent outputs. --- execution.py | 27 ++++++++++++++++----------- server.py | 4 ++-- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/execution.py b/execution.py index 35f044346..212e789ca 100644 --- a/execution.py +++ b/execution.py @@ -299,18 +299,18 @@ def validate_inputs(prompt, item, validated): required_inputs = class_inputs['required'] for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x)) + return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x)) + return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input)) + return (False, "Return type mismatch. {}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) r = validate_inputs(prompt, o_id, validated) if r[0] == False: validated[o_id] = r @@ -328,9 +328,9 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x)) + return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) if "max" in info[1] and val > info[1]["max"]: - return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x)) + return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id) if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) @@ -338,13 +338,13 @@ def validate_inputs(prompt, item, validated): ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") for r in ret: if r != True: - return (False, "{}, {}".format(class_type, r)) + return (False, "{}, {}".format(class_type, r), unique_id) else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) - ret = (True, "") + ret = (True, "", unique_id) validated[unique_id] = ret return ret @@ -356,10 +356,11 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs") + return (False, "Prompt has no outputs", [], []) good_outputs = set() errors = [] + node_errors = {} validated = {} for o in outputs: valid = False @@ -368,6 +369,7 @@ def validate_prompt(prompt): m = validate_inputs(prompt, o, validated) valid = m[0] reason = m[1] + node_id = m[2] except Exception as e: print(traceback.format_exc()) valid = False @@ -379,12 +381,15 @@ def validate_prompt(prompt): print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") errors += [(o, reason)] + if node_id not in node_errors: + node_errors[node_id] = {"message": reason, "dependent_outputs": []} + node_errors[node_id]["dependent_outputs"].append(o) if len(good_outputs) == 0: errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list)) + return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) - return (True, "", list(good_outputs)) + return (True, "", list(good_outputs), node_errors) class PromptQueue: diff --git a/server.py b/server.py index 701c0e7a7..8429a63fb 100644 --- a/server.py +++ b/server.py @@ -336,9 +336,9 @@ class PromptServer(): return web.json_response({"prompt_id": prompt_id}) else: print("invalid prompt:", valid[1]) - return web.json_response({"error": valid[1]}, status=400) + return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) else: - return web.json_response({"error": "no prompt"}, status=400) + return web.json_response({"error": "no prompt", "node_errors": []}, status=400) @routes.post("/queue") async def post_queue(request): From db27b0405a31983916d6801cf84f7f1fc4503e6a Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 13:25:50 -0400 Subject: [PATCH 02/48] object_info now returns if node is an output_node or not. 
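
Editor's note: a minimal client-side sketch of how the new flag can be consumed. The host/port are assumptions (ComfyUI's usual local default), not part of this patch; the response shape follows the handler in the diff below, which maps node class names to info dicts.

    import json
    from urllib.request import urlopen

    # /object_info returns a dict keyed by node class name; each entry now
    # carries an "output_node" boolean alongside display_name/category/etc.
    with urlopen("http://127.0.0.1:8188/object_info") as resp:
        defs = json.load(resp)

    output_nodes = sorted(n for n, info in defs.items() if info.get("output_node"))
    print(output_nodes)
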
--- server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/server.py b/server.py index 8429a63fb..c0f79cbd5 100644 --- a/server.py +++ b/server.py @@ -272,6 +272,11 @@ class PromptServer(): info['display_name'] = nodes.NODE_DISPLAY_NAME_MAPPINGS[node_class] if node_class in nodes.NODE_DISPLAY_NAME_MAPPINGS.keys() else node_class info['description'] = '' info['category'] = 'sd' + if hasattr(obj_class, 'OUTPUT_NODE') and obj_class.OUTPUT_NODE == True: + info['output_node'] = True + else: + info['output_node'] = False + if hasattr(obj_class, 'CATEGORY'): info['category'] = obj_class.CATEGORY return info From bfb13f5eee48545f1c4b0b8a377de80be84bb100 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 17:05:23 -0400 Subject: [PATCH 03/48] Remove useless call to /object_info --- web/extensions/core/colorPalette.js | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/web/extensions/core/colorPalette.js b/web/extensions/core/colorPalette.js index 2f2238a2b..bfcd847a3 100644 --- a/web/extensions/core/colorPalette.js +++ b/web/extensions/core/colorPalette.js @@ -174,7 +174,7 @@ const els = {} // const ctxMenu = LiteGraph.ContextMenu; app.registerExtension({ name: id, - init() { + addCustomNodeDefs(node_defs) { const sortObjectKeys = (unordered) => { return Object.keys(unordered).sort().reduce((obj, key) => { obj[key] = unordered[key]; @@ -182,10 +182,10 @@ app.registerExtension({ }, {}); }; - const getSlotTypes = async () => { + function getSlotTypes() { var types = []; - const defs = await api.getNodeDefs(); + const defs = node_defs; for (const nodeId in defs) { const nodeData = defs[nodeId]; @@ -212,8 +212,8 @@ app.registerExtension({ return types; }; - const completeColorPalette = async (colorPalette) => { - var types = await getSlotTypes(); + function completeColorPalette(colorPalette) { + var types = getSlotTypes(); for (const type of types) { if (!colorPalette.colors.node_slot[type]) { From 48fcc5b777b3a1ab5d6dc5fec6adaebeb32c2c93 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 22 May 2023 20:51:30 -0400 Subject: [PATCH 04/48] Parsing error crash. --- execution.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/execution.py b/execution.py index 212e789ca..25f2fcacd 100644 --- a/execution.py +++ b/execution.py @@ -374,6 +374,7 @@ def validate_prompt(prompt): print(traceback.format_exc()) valid = False reason = "Parsing error" + node_id = None if valid == True: good_outputs.add(o) @@ -381,9 +382,10 @@ def validate_prompt(prompt): print("Failed to validate prompt for output {} {}".format(o, reason)) print("output will be ignored") errors += [(o, reason)] - if node_id not in node_errors: - node_errors[node_id] = {"message": reason, "dependent_outputs": []} - node_errors[node_id]["dependent_outputs"].append(o) + if node_id is not None: + if node_id not in node_errors: + node_errors[node_id] = {"message": reason, "dependent_outputs": []} + node_errors[node_id]["dependent_outputs"].append(o) if len(good_outputs) == 0: errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) From 34887b888546716b5c5507606289ca2728bf3123 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 03:12:56 -0400 Subject: [PATCH 05/48] Add experimental bislerp algorithm for latent upscaling. It's like bilinear but with slerp. 
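
Editor's note: to make the idea concrete, here is a stripped-down sketch of the slerp step that the implementation below applies along both image axes. It is simplified for illustration and ignores the degenerate cases (zero-norm or parallel vectors) that the real code has to handle.

    import torch

    def slerp(v0, v1, t):
        # Interpolate the angle between two batched vectors instead of
        # interpolating the vectors linearly.
        v0n = v0 / torch.norm(v0, dim=1, keepdim=True)
        v1n = v1 / torch.norm(v1, dim=1, keepdim=True)
        omega = torch.acos((v0n * v1n).sum(1).clamp(-1.0, 1.0))
        so = torch.sin(omega)
        return (torch.sin((1.0 - t) * omega) / so).unsqueeze(1) * v0 + \
               (torch.sin(t * omega) / so).unsqueeze(1) * v1
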
--- comfy/utils.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++++- nodes.py | 2 +- 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 09e05d4ed..0f7b34503 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -46,6 +46,65 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +#slow and inefficient, should be optimized +def bislerp(samples, width, height): + shape = list(samples.shape) + width_scale = (shape[3]) / (width ) + height_scale = (shape[2]) / (height ) + + shape[3] = width + shape[2] = height + out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) + + def algorithm(in1, w1, in2, w2): + dims = in1.shape + val = w2 + + #flatten to batches + low = in1.reshape(dims[0], -1) + high = in2.reshape(dims[0], -1) + + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + + # in case we divide by zero + low_norm[low_norm != low_norm] = 0.0 + high_norm[high_norm != high_norm] = 0.0 + + omega = torch.acos((low_norm*high_norm).sum(1)) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res.reshape(dims) + + for x_dest in range(shape[3]): + for y_dest in range(shape[2]): + y = (y_dest) * height_scale + x = (x_dest) * width_scale + + x1 = max(math.floor(x), 0) + x2 = min(x1 + 1, samples.shape[3] - 1) + y1 = max(math.floor(y), 0) + y2 = min(y1 + 1, samples.shape[2] - 1) + + in1 = samples[:,:,y1,x1] + in2 = samples[:,:,y1,x2] + in3 = samples[:,:,y2,x1] + in4 = samples[:,:,y2,x2] + + if (x1 == x2) and (y1 == y2): + out_value = in1 + elif (x1 == x2): + out_value = algorithm(in1, (y2 - y), in3, (y - y1)) + elif (y1 == y2): + out_value = algorithm(in1, (x2 - x), in2, (x - x1)) + else: + o1 = algorithm(in1, (x2 - x), in2, (x - x1)) + o2 = algorithm(in3, (x2 - x), in4, (x - x1)) + out_value = algorithm(o1, (y2 - y), o2, (y - y1)) + + out1[:,:,y_dest,x_dest] = out_value + return out1 + def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": old_width = samples.shape[3] @@ -61,7 +120,11 @@ def common_upscale(samples, width, height, upscale_method, crop): s = samples[:,:,y:old_height-y,x:old_width-x] else: s = samples - return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) + + if upscale_method == "bislerp": + return bislerp(s, width, height) + else: + return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method) def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap))) diff --git a/nodes.py b/nodes.py index bae330bc9..e5cec2632 100644 --- a/nodes.py +++ b/nodes.py @@ -749,7 +749,7 @@ class RepeatLatentBatch: return (s,) class LatentUpscale: - upscale_methods = ["nearest-exact", "bilinear", "area"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] crop_methods = ["disabled", "center"] @classmethod From 451fb4169ad900e5d33b540f039f56ced9a76157 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 11:35:32 -0400 Subject: [PATCH 06/48] Fix 'git pull' not working on the standalones. 
--- .github/workflows/windows_release_cu118_package.yml | 1 + .github/workflows/windows_release_nightly_pytorch.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/windows_release_cu118_package.yml b/.github/workflows/windows_release_cu118_package.yml index 15322c86a..2d6048a23 100644 --- a/.github/workflows/windows_release_cu118_package.yml +++ b/.github/workflows/windows_release_cu118_package.yml @@ -30,6 +30,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + persist-credentials: false - shell: bash run: | cd .. diff --git a/.github/workflows/windows_release_nightly_pytorch.yml b/.github/workflows/windows_release_nightly_pytorch.yml index b6a18ec0a..767a7216b 100644 --- a/.github/workflows/windows_release_nightly_pytorch.yml +++ b/.github/workflows/windows_release_nightly_pytorch.yml @@ -17,6 +17,7 @@ jobs: - uses: actions/checkout@v3 with: fetch-depth: 0 + persist-credentials: false - uses: actions/setup-python@v4 with: python-version: '3.11.3' From b8ccbec6d893d34dab90d2418a3fe00969251fa8 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 11:40:24 -0400 Subject: [PATCH 07/48] Various improvements to bislerp. --- comfy/utils.py | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 0f7b34503..300eda6aa 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -56,35 +56,42 @@ def bislerp(samples, width, height): shape[2] = height out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) - def algorithm(in1, w1, in2, w2): + def algorithm(in1, in2, t): dims = in1.shape - val = w2 + val = t #flatten to batches low = in1.reshape(dims[0], -1) high = in2.reshape(dims[0], -1) - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) + low_weight = torch.norm(low, dim=1, keepdim=True) + low_weight[low_weight == 0] = 0.0000000001 + low_norm = low/low_weight + high_weight = torch.norm(high, dim=1, keepdim=True) + high_weight[high_weight == 0] = 0.0000000001 + high_norm = high/high_weight - # in case we divide by zero - low_norm[low_norm != low_norm] = 0.0 - high_norm[high_norm != high_norm] = 0.0 - - omega = torch.acos((low_norm*high_norm).sum(1)) + dot_prod = (low_norm*high_norm).sum(1) + dot_prod[dot_prod > 0.9995] = 0.9995 + dot_prod[dot_prod < -0.9995] = -0.9995 + omega = torch.acos(dot_prod) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm + res *= (low_weight * (1.0-val) + high_weight * val) return res.reshape(dims) for x_dest in range(shape[3]): for y_dest in range(shape[2]): - y = (y_dest) * height_scale - x = (x_dest) * width_scale + y = (y_dest + 0.5) * height_scale - 0.5 + x = (x_dest + 0.5) * width_scale - 0.5 x1 = max(math.floor(x), 0) x2 = min(x1 + 1, samples.shape[3] - 1) + wx = x - math.floor(x) + y1 = max(math.floor(y), 0) y2 = min(y1 + 1, samples.shape[2] - 1) + wy = y - math.floor(y) in1 = samples[:,:,y1,x1] in2 = samples[:,:,y1,x2] @@ -94,13 +101,13 @@ def bislerp(samples, width, height): if (x1 == x2) and (y1 == y2): out_value = in1 elif (x1 == x2): - out_value = algorithm(in1, (y2 - y), in3, (y - y1)) + out_value = algorithm(in1, in3, wy) elif (y1 == y2): - out_value = algorithm(in1, (x2 - x), in2, (x - x1)) + out_value = algorithm(in1, in2, wx) else: - o1 = 
algorithm(in1, (x2 - x), in2, (x - x1)) - o2 = algorithm(in3, (x2 - x), in4, (x - x1)) - out_value = algorithm(o1, (y2 - y), o2, (y - y1)) + o1 = algorithm(in1, in2, wx) + o2 = algorithm(in3, in4, wx) + out_value = algorithm(o1, o2, wy) out1[:,:,y_dest,x_dest] = out_value return out1 From c00bb1a0b78f0d2cf2e4ec2dd9ae7d61cb07a637 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 12:53:38 -0400 Subject: [PATCH 08/48] Add a latent upscale by node. --- nodes.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nodes.py b/nodes.py index e5cec2632..f0a93ebd5 100644 --- a/nodes.py +++ b/nodes.py @@ -768,6 +768,25 @@ class LatentUpscale: s["samples"] = comfy.utils.common_upscale(samples["samples"], width // 8, height // 8, upscale_method, crop) return (s,) +class LatentUpscaleBy: + upscale_methods = ["nearest-exact", "bilinear", "area", "bislerp"] + + @classmethod + def INPUT_TYPES(s): + return {"required": { "samples": ("LATENT",), "upscale_method": (s.upscale_methods,), + "scale_by": ("FLOAT", {"default": 1.5, "min": 0.01, "max": 8.0, "step": 0.01}),}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "upscale" + + CATEGORY = "latent" + + def upscale(self, samples, upscale_method, scale_by): + s = samples.copy() + width = round(samples["samples"].shape[3] * scale_by) + height = round(samples["samples"].shape[2] * scale_by) + s["samples"] = comfy.utils.common_upscale(samples["samples"], width, height, upscale_method, "disabled") + return (s,) + class LatentRotate: @classmethod def INPUT_TYPES(s): @@ -1244,6 +1263,7 @@ NODE_CLASS_MAPPINGS = { "VAELoader": VAELoader, "EmptyLatentImage": EmptyLatentImage, "LatentUpscale": LatentUpscale, + "LatentUpscaleBy": LatentUpscaleBy, "LatentFromBatch": LatentFromBatch, "RepeatLatentBatch": RepeatLatentBatch, "SaveImage": SaveImage, @@ -1322,6 +1342,7 @@ NODE_DISPLAY_NAME_MAPPINGS = { "LatentCrop": "Crop Latent", "EmptyLatentImage": "Empty Latent Image", "LatentUpscale": "Upscale Latent", + "LatentUpscaleBy": "Upscale Latent By", "LatentComposite": "Latent Composite", "LatentFromBatch" : "Latent From Batch", "RepeatLatentBatch": "Repeat Latent Batch", From 7310290f17aad79480edb92f22cd58f0997db964 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 23 May 2023 22:26:50 -0400 Subject: [PATCH 09/48] Pull in latest upscale model code from chainner. 
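
Editor's note: these architecture files are consumed through the chaiNNer-style loader in comfy_extras/chainner_models/model_loading.py, roughly as sketched here. The entry-point name and checkpoint path are assumptions for illustration, not confirmed by this patch.

    import torch
    from comfy_extras.chainner_models import model_loading

    # The loader inspects state-dict keys to detect the architecture
    # (ESRGAN/RRDB, OmniSR, ...) and instantiates the matching class.
    sd = torch.load("models/upscale_models/model.pth", map_location="cpu")
    model = model_loading.load_state_dict(sd).eval()
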
--- .../architecture/OmniSR/ChannelAttention.py | 110 ++++ .../architecture/OmniSR/LICENSE | 201 ++++++ .../architecture/OmniSR/OSA.py | 577 ++++++++++++++++++ .../architecture/OmniSR/OSAG.py | 60 ++ .../architecture/OmniSR/OmniSR.py | 133 ++++ .../architecture/OmniSR/esa.py | 294 +++++++++ .../architecture/OmniSR/layernorm.py | 70 +++ .../architecture/OmniSR/pixelshuffle.py | 31 + .../chainner_models/architecture/RRDB.py | 17 +- .../chainner_models/architecture/block.py | 30 + comfy_extras/chainner_models/model_loading.py | 5 + comfy_extras/chainner_models/types.py | 4 +- 12 files changed, 1530 insertions(+), 2 deletions(-) create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/LICENSE create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OSA.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OSAG.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/esa.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/layernorm.py create mode 100644 comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py diff --git a/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py new file mode 100644 index 000000000..f4d52aa1e --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py @@ -0,0 +1,110 @@ +import math + +import torch.nn as nn + + +class CA_layer(nn.Module): + def __init__(self, channel, reduction=16): + super(CA_layer, self).__init__() + # global average pooling + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False), + nn.GELU(), + nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False), + # nn.Sigmoid() + ) + + def forward(self, x): + y = self.fc(self.gap(x)) + return x * y.expand_as(x) + + +class Simple_CA_layer(nn.Module): + def __init__(self, channel): + super(Simple_CA_layer, self).__init__() + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=channel, + out_channels=channel, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + return x * self.fc(self.gap(x)) + + +class ECA_layer(nn.Module): + """Constructs a ECA module. + Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.avg_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) + + +class ECA_MaxPool_layer(nn.Module): + """Constructs a ECA module. 
+ Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + + def __init__(self, channel): + super(ECA_MaxPool_layer, self).__init__() + + b = 1 + gamma = 2 + k_size = int(abs(math.log(channel, 2) + b) / gamma) + k_size = k_size if k_size % 2 else k_size + 1 + self.max_pool = nn.AdaptiveMaxPool2d(1) + self.conv = nn.Conv1d( + 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False + ) + # self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + # b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.max_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + # y = self.sigmoid(y) + + return x * y.expand_as(x) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/LICENSE b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
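
Editor's note: the OSA.py file added next leans heavily on einops rearrange to hop between image layout and non-overlapping attention windows. A standalone illustration of that round trip, with toy sizes chosen so the window size divides the spatial dims:

    import torch
    from einops import rearrange

    b, d, h, w, ws = 1, 8, 16, 16, 8  # batch, channels, height, width, window

    x = torch.randn(b, d, h, w)
    # Partition into ws x ws windows (the "block-like attention" layout)...
    win = rearrange(x, "b d (x w1) (y w2) -> b x y w1 w2 d", w1=ws, w2=ws)
    # ...and merge back; the partition/merge round trip is lossless.
    assert torch.equal(x, rearrange(win, "b x y w1 w2 d -> b d (x w1) (y w2)"))
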
diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSA.py b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py new file mode 100644 index 000000000..d7a129696 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSA.py @@ -0,0 +1,577 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSA.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:07:42 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat +from einops.layers.torch import Rearrange, Reduce +from torch import einsum, nn + +from .layernorm import LayerNorm2d + +# helpers + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def cast_tuple(val, length=1): + return val if isinstance(val, tuple) else ((val,) * length) + + +# helper classes + + +class PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class Conv_PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = LayerNorm2d(dim) + self.fn = fn + + def forward(self, x): + return self.fn(self.norm(x)) + x + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(inner_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=2, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + self.net = nn.Sequential( + nn.Conv2d(dim, inner_dim, 1, 1, 0), + nn.GELU(), + nn.Dropout(dropout), + nn.Conv2d(inner_dim, dim, 1, 1, 0), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Gated_Conv_FeedForward(nn.Module): + def __init__(self, dim, mult=1, bias=False, dropout=0.0): + super().__init__() + + hidden_features = int(dim * mult) + + self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) + + self.dwconv = nn.Conv2d( + hidden_features * 2, + hidden_features * 2, + kernel_size=3, + stride=1, + padding=1, + groups=hidden_features * 2, + bias=bias, + ) + + self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) + + def forward(self, x): + x = self.project_in(x) + x1, x2 = self.dwconv(x).chunk(2, dim=1) + x = F.gelu(x1) * x2 + x = self.project_out(x) + return x + + +# MBConv + + +class SqueezeExcitation(nn.Module): + def __init__(self, dim, shrinkage_rate=0.25): + super().__init__() + hidden_dim = int(dim * shrinkage_rate) + + self.gate = nn.Sequential( + Reduce("b c h w -> b c", "mean"), + nn.Linear(dim, hidden_dim, bias=False), + nn.SiLU(), + nn.Linear(hidden_dim, dim, bias=False), + nn.Sigmoid(), + Rearrange("b c -> b c 1 1"), + ) + + def forward(self, x): + return x * self.gate(x) + + +class MBConvResidual(nn.Module): + def __init__(self, fn, dropout=0.0): + super().__init__() + self.fn = fn + self.dropsample = Dropsample(dropout) + + def forward(self, x): + out = self.fn(x) + out = self.dropsample(out) + return out + x + 
+ +class Dropsample(nn.Module): + def __init__(self, prob=0): + super().__init__() + self.prob = prob + + def forward(self, x): + device = x.device + + if self.prob == 0.0 or (not self.training): + return x + + keep_mask = ( + torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_() + > self.prob + ) + return x * keep_mask / (1 - self.prob) + + +def MBConv( + dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0 +): + hidden_dim = int(expansion_rate * dim_out) + stride = 2 if downsample else 1 + + net = nn.Sequential( + nn.Conv2d(dim_in, hidden_dim, 1), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + nn.Conv2d( + hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim + ), + # nn.BatchNorm2d(hidden_dim), + nn.GELU(), + SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate), + nn.Conv2d(hidden_dim, dim_out, 1), + # nn.BatchNorm2d(dim_out) + ) + + if dim_in == dim_out and not downsample: + net = MBConvResidual(net, dropout=dropout) + + return net + + +# attention related classes +class Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.to_qkv = nn.Linear(dim, dim * 3, bias=False) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Sequential( + nn.Linear(dim, dim, bias=False), nn.Dropout(dropout) + ) + + # relative positional bias + if self.with_pe: + self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads) + + pos = torch.arange(window_size) + grid = torch.stack(torch.meshgrid(pos, pos)) + grid = rearrange(grid, "c i j -> (i j) c") + rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange( + grid, "j ... -> 1 j ..." + ) + rel_pos += window_size - 1 + rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum( + dim=-1 + ) + + self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False) + + def forward(self, x): + batch, height, width, window_height, window_width, _, device, h = ( + *x.shape, + x.device, + self.heads, + ) + + # flatten + + x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d") + + # project for queries, keys, values + + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + # split heads + + q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v)) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # add positional bias + if self.with_pe: + bias = self.rel_pos_bias(self.rel_pos_indices) + sim = sim + rearrange(bias, "i j h -> h i j") + + # attention + + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + + out = rearrange( + out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width + ) + + # combine heads out + + out = self.to_out(out) + return rearrange(out, "(b x y) ... 
-> b x y ...", x=height, y=width) + + +class Block_Attention(nn.Module): + def __init__( + self, + dim, + dim_head=32, + bias=False, + dropout=0.0, + window_size=7, + with_pe=True, + ): + super().__init__() + assert ( + dim % dim_head + ) == 0, "dimension should be divisible by dimension per head" + + self.heads = dim // dim_head + self.ps = window_size + self.scale = dim_head**-0.5 + self.with_pe = with_pe + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + + self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout)) + + self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + # project for queries, keys, values + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + q, k, v = qkv.chunk(3, dim=1) + + # split heads + + q, k, v = map( + lambda t: rearrange( + t, + "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d", + h=self.heads, + w1=self.ps, + w2=self.ps, + ), + (q, k, v), + ) + + # scale + + q = q * self.scale + + # sim + + sim = einsum("b h i d, b h j d -> b h i j", q, k) + + # attention + attn = self.attend(sim) + + # aggregate + + out = einsum("b h i j, b h j d -> b h i d", attn, v) + + # merge heads + out = rearrange( + out, + "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)", + x=h // self.ps, + y=w // self.ps, + head=self.heads, + w1=self.ps, + w2=self.ps, + ) + + out = self.to_out(out) + return out + + +class Channel_Attention(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class Channel_Attention_grid(nn.Module): + def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7): + super(Channel_Attention_grid, self).__init__() + self.heads = heads + + self.temperature = nn.Parameter(torch.ones(heads, 1, 1)) + + self.ps = window_size + + self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias) + self.qkv_dwconv = nn.Conv2d( + dim * 3, + dim * 3, + kernel_size=3, + stride=1, + padding=1, + groups=dim * 3, + bias=bias, + ) + self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv_dwconv(self.qkv(x)) + qkv = qkv.chunk(3, dim=1) + + q, k, v = map( + lambda t: rearrange( + t, + "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)", + ph=self.ps, + pw=self.ps, + head=self.heads, + ), + qkv, + ) + + q = 
F.normalize(q, dim=-1) + k = F.normalize(k, dim=-1) + + attn = (q @ k.transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + out = attn @ v + + out = rearrange( + out, + "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)", + h=h // self.ps, + w=w // self.ps, + ph=self.ps, + pw=self.ps, + head=self.heads, + ) + + out = self.project_out(out) + + return out + + +class OSA_Block(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + ffn_bias=True, + window_size=8, + with_pe=False, + dropout=0.0, + ): + super(OSA_Block, self).__init__() + + w = window_size + + self.layer = nn.Sequential( + MBConv( + channel_num, + channel_num, + downsample=False, + expansion_rate=1, + shrinkage_rate=0.25, + ), + Rearrange( + "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w + ), # block-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + Rearrange( + "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w + ), # grid-like attention + PreNormResidual( + channel_num, + Attention( + dim=channel_num, + dim_head=channel_num // 4, + dropout=dropout, + window_size=window_size, + with_pe=with_pe, + ), + ), + Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + # channel-like attention + Conv_PreNormResidual( + channel_num, + Channel_Attention_grid( + dim=channel_num, heads=4, dropout=dropout, window_size=window_size + ), + ), + Conv_PreNormResidual( + channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout) + ), + ) + + def forward(self, x): + out = self.layer(x) + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py new file mode 100644 index 000000000..477e81f9d --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OSAG.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:08:49 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + + +import torch.nn as nn + +from .esa import ESA +from .OSA import OSA_Block + + +class OSAG(nn.Module): + def __init__( + self, + channel_num=64, + bias=True, + block_num=4, + ffn_bias=False, + window_size=0, + pe=False, + ): + super(OSAG, self).__init__() + + # print("window_size: %d" % (window_size)) + # print("with_pe", pe) + # print("ffn_bias: %d" % (ffn_bias)) + + # block_script_name = kwargs.get("block_script_name", "OSA") + # block_class_name = kwargs.get("block_class_name", "OSA_Block") + + # script_name = "." 
+ block_script_name + # package = __import__(script_name, fromlist=True) + block_class = OSA_Block # getattr(package, block_class_name) + group_list = [] + for _ in range(block_num): + temp_res = block_class( + channel_num, + bias, + ffn_bias=ffn_bias, + window_size=window_size, + with_pe=pe, + ) + group_list.append(temp_res) + group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias)) + self.residual_layer = nn.Sequential(*group_list) + esa_channel = max(channel_num // 4, 16) + self.esa = ESA(esa_channel, channel_num) + + def forward(self, x): + out = self.residual_layer(x) + out = out + x + return self.esa(out) diff --git a/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py new file mode 100644 index 000000000..dec169520 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: OmniSR.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Sunday, 23rd April 2023 3:06:36 pm +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .OSAG import OSAG +from .pixelshuffle import pixelshuffle_block + + +class OmniSR(nn.Module): + def __init__( + self, + state_dict, + **kwargs, + ): + super(OmniSR, self).__init__() + self.state = state_dict + + bias = True # Fine to assume this for now + block_num = 1 # Fine to assume this for now + ffn_bias = True + pe = True + + num_feat = state_dict["input.weight"].shape[0] or 64 + num_in_ch = state_dict["input.weight"].shape[1] or 3 + num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh + + pixelshuffle_shape = state_dict["up.0.weight"].shape[0] + up_scale = math.sqrt(pixelshuffle_shape / num_out_ch) + if up_scale - int(up_scale) > 0: + print( + "out_nc is probably different than in_nc, scale calculation might be wrong" + ) + up_scale = int(up_scale) + res_num = 0 + for key in state_dict.keys(): + if "residual_layer" in key: + temp_res_num = int(key.split(".")[1]) + if temp_res_num > res_num: + res_num = temp_res_num + res_num = res_num + 1 # zero-indexed + + residual_layer = [] + self.res_num = res_num + + self.window_size = 8 # we can just assume this for now, but there's probably a way to calculate it (just need to get the sqrt of the right layer) + self.up_scale = up_scale + + for _ in range(res_num): + temp_res = OSAG( + channel_num=num_feat, + bias=bias, + block_num=block_num, + ffn_bias=ffn_bias, + window_size=self.window_size, + pe=pe, + ) + residual_layer.append(temp_res) + self.residual_layer = nn.Sequential(*residual_layer) + self.input = nn.Conv2d( + in_channels=num_in_ch, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.output = nn.Conv2d( + in_channels=num_feat, + out_channels=num_feat, + kernel_size=3, + stride=1, + padding=1, + bias=bias, + ) + self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias) + + # self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, sqrt(2. 
/ n)) + + # chaiNNer specific stuff + self.model_arch = "OmniSR" + self.sub_type = "SR" + self.in_nc = num_in_ch + self.out_nc = num_out_ch + self.num_feat = num_feat + self.scale = up_scale + + self.supports_fp16 = True # TODO: Test this + self.supports_bfp16 = True + self.min_size_restriction = 16 + + self.load_state_dict(state_dict, strict=False) + + def check_image_size(self, x): + _, _, h, w = x.size() + # import pdb; pdb.set_trace() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + # x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0) + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + residual = self.input(x) + out = self.residual_layer(residual) + + # origin + out = torch.add(self.output(out), residual) + out = self.up(out) + + out = out[:, :, : H * self.up_scale, : W * self.up_scale] + return out diff --git a/comfy_extras/chainner_models/architecture/OmniSR/esa.py b/comfy_extras/chainner_models/architecture/OmniSR/esa.py new file mode 100644 index 000000000..f9ce7f7a6 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/esa.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: esa.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:06 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .layernorm import LayerNorm2d + + +def moment(x, dim=(2, 3), k=2): + assert len(x.size()) == 4 + mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1) + mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim) + return mk + + +class ESA(nn.Module): + """ + Modification of Enhanced Spatial Attention (ESA), which is proposed by + `Residual Feature Aggregation Network for Image Super-Resolution` + Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes + are deleted. 
+ """ + + def __init__(self, esa_channels, n_feats, conv=nn.Conv2d): + super(ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0) + self.conv3 = conv(f, f, kernel_size=3, padding=1) + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + c1 = self.conv2(c1_) + v_max = F.max_pool2d(c1, kernel_size=7, stride=3) + c3 = self.conv3(v_max) + c3 = F.interpolate( + c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False + ) + cf = self.conv_f(c1_) + c4 = self.conv4(c3 + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.conv1(x) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class LK_ESA_LN(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(LK_ESA_LN, self).__init__() + f = esa_channels + self.conv1 = conv(n_feats, f, kernel_size=1) + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.norm = LayerNorm2d(n_feats) + + self.vec_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=2, + bias=bias, + ) + self.vec_conv3x1 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(1, 3), + padding=(0, 1), + groups=2, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=2, + bias=bias, + ) + self.hor_conv1x3 = nn.Conv2d( + in_channels=f * kernel_expand, + out_channels=f * kernel_expand, + kernel_size=(3, 1), + padding=(1, 0), + groups=2, + bias=bias, + ) + + self.conv4 = conv(f, n_feats, kernel_size=1) + self.sigmoid = nn.Sigmoid() + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + c1_ = self.norm(x) + c1_ = self.conv1(c1_) + + res = self.vec_conv(c1_) + self.vec_conv3x1(c1_) + res = self.hor_conv(res) + 
self.hor_conv1x3(res) + + cf = self.conv_f(c1_) + c4 = self.conv4(res + cf) + m = self.sigmoid(c4) + return x * m + + +class AdaGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaGuidedFilter, self).__init__() + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=n_feats, + out_channels=1, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + self.r = 5 + + def box_filter(self, x, r): + channel = x.shape[1] + kernel_size = 2 * r + 1 + weight = 1.0 / (kernel_size**2) + box_kernel = weight * torch.ones( + (channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device + ) + output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel) + return output + + def forward(self, x): + _, _, H, W = x.shape + N = self.box_filter( + torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r + ) + + # epsilon = self.fc(self.gap(x)) + # epsilon = torch.pow(epsilon, 2) + epsilon = 1e-2 + + mean_x = self.box_filter(x, self.r) / N + var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x + + A = var_x / (var_x + epsilon) + b = (1 - A) * mean_x + m = A * x + b + + # mean_A = self.box_filter(A, self.r) / N + # mean_b = self.box_filter(b, self.r) / N + # m = mean_A * x + mean_b + return x * m + + +class AdaConvGuidedFilter(nn.Module): + def __init__( + self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True + ): + super(AdaConvGuidedFilter, self).__init__() + f = esa_channels + + self.conv_f = conv(f, f, kernel_size=1) + + kernel_size = 17 + kernel_expand = kernel_expand + padding = kernel_size // 2 + + self.vec_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(1, kernel_size), + padding=(0, padding), + groups=f, + bias=bias, + ) + + self.hor_conv = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=(kernel_size, 1), + padding=(padding, 0), + groups=f, + bias=bias, + ) + + self.gap = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Conv2d( + in_channels=f, + out_channels=f, + kernel_size=1, + padding=0, + stride=1, + groups=1, + bias=True, + ) + + def forward(self, x): + y = self.vec_conv(x) + y = self.hor_conv(y) + + sigma = torch.pow(y, 2) + epsilon = self.fc(self.gap(y)) + + weight = sigma / (sigma + epsilon) + + m = weight * x + (1 - weight) + + return x * m diff --git a/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py new file mode 100644 index 000000000..731a25f75 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: layernorm.py +# Created Date: Tuesday April 28th 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Thursday, 20th April 2023 9:28:20 am +# Modified By: Chen Xuanhong +# Copyright (c) 2020 Shanghai Jiao Tong University +############################################################# + +import torch +import torch.nn as nn + + +class LayerNormFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias, eps): + ctx.eps = eps + N, C, H, W = x.size() + mu = x.mean(1, keepdim=True) + var = (x - mu).pow(2).mean(1, keepdim=True) + y = (x - mu) / (var + eps).sqrt() + ctx.save_for_backward(y, var, weight) + y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) + return y + + @staticmethod + def backward(ctx, 
grad_output): + eps = ctx.eps + + N, C, H, W = grad_output.size() + y, var, weight = ctx.saved_variables + g = grad_output * weight.view(1, C, 1, 1) + mean_g = g.mean(dim=1, keepdim=True) + + mean_gy = (g * y).mean(dim=1, keepdim=True) + gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) + return ( + gx, + (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), + grad_output.sum(dim=3).sum(dim=2).sum(dim=0), + None, + ) + + +class LayerNorm2d(nn.Module): + def __init__(self, channels, eps=1e-6): + super(LayerNorm2d, self).__init__() + self.register_parameter("weight", nn.Parameter(torch.ones(channels))) + self.register_parameter("bias", nn.Parameter(torch.zeros(channels))) + self.eps = eps + + def forward(self, x): + return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) + + +class GRN(nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True) + Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x diff --git a/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py new file mode 100644 index 000000000..4260fb7c9 --- /dev/null +++ b/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding:utf-8 -*- +############################################################# +# File: pixelshuffle.py +# Created Date: Friday July 1st 2022 +# Author: Chen Xuanhong +# Email: chenxuanhongzju@outlook.com +# Last Modified: Friday, 1st July 2022 10:18:39 am +# Modified By: Chen Xuanhong +# Copyright (c) 2022 Shanghai Jiao Tong University +############################################################# + +import torch.nn as nn + + +def pixelshuffle_block( + in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False +): + """ + Upsample features according to `upscale_factor`. 
+ """ + padding = kernel_size // 2 + conv = nn.Conv2d( + in_channels, + out_channels * (upscale_factor**2), + kernel_size, + padding=1, + bias=bias, + ) + pixel_shuffle = nn.PixelShuffle(upscale_factor) + return nn.Sequential(*[conv, pixel_shuffle]) diff --git a/comfy_extras/chainner_models/architecture/RRDB.py b/comfy_extras/chainner_models/architecture/RRDB.py index 4d52f05dd..b50db7c24 100644 --- a/comfy_extras/chainner_models/architecture/RRDB.py +++ b/comfy_extras/chainner_models/architecture/RRDB.py @@ -79,6 +79,12 @@ class RRDBNet(nn.Module): self.scale: int = self.get_scale() self.num_filters: int = self.state[self.key_arr[0]].shape[0] + c2x2 = False + if self.state["model.0.weight"].shape[-2] == 2: + c2x2 = True + self.scale = round(math.sqrt(self.scale / 4)) + self.model_arch = "ESRGAN-2c2" + self.supports_fp16 = True self.supports_bfp16 = True self.min_size_restriction = None @@ -105,11 +111,15 @@ class RRDBNet(nn.Module): out_nc=self.num_filters, upscale_factor=3, act_type=self.act, + c2x2=c2x2, ) else: upsample_blocks = [ upsample_block( - in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act + in_nc=self.num_filters, + out_nc=self.num_filters, + act_type=self.act, + c2x2=c2x2, ) for _ in range(int(math.log(self.scale, 2))) ] @@ -122,6 +132,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), B.ShortcutBlock( B.sequential( @@ -138,6 +149,7 @@ class RRDBNet(nn.Module): act_type=self.act, mode="CNA", plus=self.plus, + c2x2=c2x2, ) for _ in range(self.num_blocks) ], @@ -149,6 +161,7 @@ class RRDBNet(nn.Module): norm_type=self.norm, act_type=None, mode=self.mode, + c2x2=c2x2, ), ) ), @@ -160,6 +173,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=self.act, + c2x2=c2x2, ), # hr_conv1 B.conv_block( @@ -168,6 +182,7 @@ class RRDBNet(nn.Module): kernel_size=3, norm_type=None, act_type=None, + c2x2=c2x2, ), ) diff --git a/comfy_extras/chainner_models/architecture/block.py b/comfy_extras/chainner_models/architecture/block.py index 214642cc4..d7bc5d227 100644 --- a/comfy_extras/chainner_models/architecture/block.py +++ b/comfy_extras/chainner_models/architecture/block.py @@ -141,6 +141,19 @@ def sequential(*args): ConvMode = Literal["CNA", "NAC", "CNAC"] +# 2x2x2 Conv Block +def conv_block_2c2( + in_nc, + out_nc, + act_type="relu", +): + return sequential( + nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1), + nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0), + act(act_type) if act_type else None, + ) + + def conv_block( in_nc: int, out_nc: int, @@ -153,12 +166,17 @@ def conv_block( norm_type: str | None = None, act_type: str | None = "relu", mode: ConvMode = "CNA", + c2x2=False, ): """ Conv layer with padding, normalization, activation mode: CNA --> Conv -> Norm -> Act NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16) """ + + if c2x2: + return conv_block_2c2(in_nc, out_nc, act_type=act_type) + assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) padding = get_valid_padding(kernel_size, dilation) p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None @@ -285,6 +303,7 @@ class RRDB(nn.Module): _convtype="Conv2D", _spectral_norm=False, plus=False, + c2x2=False, ): super(RRDB, self).__init__() self.RDB1 = ResidualDenseBlock_5C( @@ -298,6 +317,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) self.RDB2 = ResidualDenseBlock_5C( nf, @@ -310,6 +330,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) 
self.RDB3 = ResidualDenseBlock_5C( nf, @@ -322,6 +343,7 @@ class RRDB(nn.Module): act_type, mode, plus=plus, + c2x2=c2x2, ) def forward(self, x): @@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module): act_type="leakyrelu", mode: ConvMode = "CNA", plus=False, + c2x2=False, ): super(ResidualDenseBlock_5C, self).__init__() @@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv2 = conv_block( nf + gc, @@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv3 = conv_block( nf + 2 * gc, @@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) self.conv4 = conv_block( nf + 3 * gc, @@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=act_type, mode=mode, + c2x2=c2x2, ) if mode == "CNA": last_act = None @@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module): norm_type=norm_type, act_type=last_act, mode=mode, + c2x2=c2x2, ) def forward(self, x): @@ -499,6 +527,7 @@ def upconv_block( norm_type: str | None = None, act_type="relu", mode="nearest", + c2x2=False, ): # Up conv # described in https://distill.pub/2016/deconv-checkerboard/ @@ -512,5 +541,6 @@ def upconv_block( pad_type=pad_type, norm_type=norm_type, act_type=act_type, + c2x2=c2x2, ) return sequential(upsample, conv) diff --git a/comfy_extras/chainner_models/model_loading.py b/comfy_extras/chainner_models/model_loading.py index 8234ac5d1..2e66e6247 100644 --- a/comfy_extras/chainner_models/model_loading.py +++ b/comfy_extras/chainner_models/model_loading.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -32,6 +33,7 @@ def load_state_dict(state_dict) -> PyTorchModel: state_dict = state_dict["params"] state_dict_keys = list(state_dict.keys()) + # SRVGGNet Real-ESRGAN (v2) if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys: model = RealESRGANv2(state_dict) @@ -79,6 +81,9 @@ def load_state_dict(state_dict) -> PyTorchModel: # MAT elif "synthesis.first_stage.conv_first.conv.resample_filter" in state_dict_keys: model = MAT(state_dict) + # Omni-SR + elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys: + model = OmniSR(state_dict) # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1 else: try: diff --git a/comfy_extras/chainner_models/types.py b/comfy_extras/chainner_models/types.py index 8e2bef47a..1906c0c7f 100644 --- a/comfy_extras/chainner_models/types.py +++ b/comfy_extras/chainner_models/types.py @@ -6,6 +6,7 @@ from .architecture.face.restoreformer_arch import RestoreFormer from .architecture.HAT import HAT from .architecture.LaMa import LaMa from .architecture.MAT import MAT +from .architecture.OmniSR.OmniSR import OmniSR from .architecture.RRDB import RRDBNet as ESRGAN from .architecture.SPSR import SPSRNet as SPSR from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2 @@ -13,7 +14,7 @@ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN from .architecture.Swin2SR import Swin2SR from .architecture.SwinIR import SwinIR -PyTorchSRModels = (RealESRGANv2, SPSR, 
SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT) +PyTorchSRModels = (RealESRGANv2, SPSR, SwiftSRGAN, ESRGAN, SwinIR, Swin2SR, HAT, OmniSR) PyTorchSRModel = Union[ RealESRGANv2, SPSR, @@ -22,6 +23,7 @@ PyTorchSRModel = Union[ SwinIR, Swin2SR, HAT, + OmniSR, ] From 9b1396e93a19748dd4c4bb35637638bb0f91b5f0 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 24 May 2023 14:01:11 -0400 Subject: [PATCH 10/48] Fix issue importing other ui prompts. --- web/scripts/pnginfo.js | 1 + 1 file changed, 1 insertion(+) diff --git a/web/scripts/pnginfo.js b/web/scripts/pnginfo.js index 8ddb7a1c5..977b5ac2f 100644 --- a/web/scripts/pnginfo.js +++ b/web/scripts/pnginfo.js @@ -69,6 +69,7 @@ export async function importA1111(graph, parameters) { const embeddings = await api.getEmbeddings(); const opts = parameters .substr(p) + .split("\n")[1] .split(",") .reduce((p, n) => { const s = n.split(":"); From 8b4b0c3188110e1faa8865570637172ab4b60ba1 Mon Sep 17 00:00:00 2001 From: BlenderNeko <126974546+BlenderNeko@users.noreply.github.com> Date: Thu, 25 May 2023 19:23:47 +0200 Subject: [PATCH 11/48] vecorized bislerp --- comfy/utils.py | 117 +++++++++++++++++++++++++++---------------------- 1 file changed, 64 insertions(+), 53 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 300eda6aa..cc0e5069a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,6 @@ import torch import math +import einops def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -46,71 +47,81 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd -#slow and inefficient, should be optimized def bislerp(samples, width, height): - shape = list(samples.shape) - width_scale = (shape[3]) / (width ) - height_scale = (shape[2]) / (height ) + def slerp(b1, b2, r): + '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC''' + + c = b1.shape[-1] - shape[3] = width - shape[2] = height - out1 = torch.empty(shape, dtype=samples.dtype, layout=samples.layout, device=samples.device) + #norms + b1_norms = torch.norm(b1, dim=-1, keepdim=True) + b2_norms = torch.norm(b2, dim=-1, keepdim=True) - def algorithm(in1, in2, t): - dims = in1.shape - val = t + #normalize + b1_normalized = b1 / b1_norms + b2_normalized = b2 / b2_norms - #flatten to batches - low = in1.reshape(dims[0], -1) - high = in2.reshape(dims[0], -1) + #zero when norms are zero + b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0 + b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0 - low_weight = torch.norm(low, dim=1, keepdim=True) - low_weight[low_weight == 0] = 0.0000000001 - low_norm = low/low_weight - high_weight = torch.norm(high, dim=1, keepdim=True) - high_weight[high_weight == 0] = 0.0000000001 - high_norm = high/high_weight - - dot_prod = (low_norm*high_norm).sum(1) - dot_prod[dot_prod > 0.9995] = 0.9995 - dot_prod[dot_prod < -0.9995] = -0.9995 - omega = torch.acos(dot_prod) + #slerp + dot = (b1_normalized*b2_normalized).sum(1) + omega = torch.acos(dot) so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low_norm + (torch.sin(val*omega)/so).unsqueeze(1) * high_norm - res *= (low_weight * (1.0-val) + high_weight * val) - return res.reshape(dims) - for x_dest in range(shape[3]): - for y_dest in range(shape[2]): - y = (y_dest + 0.5) * height_scale - 0.5 - x = (x_dest + 0.5) * width_scale - 0.5 + #technically not mathematically correct, but more pleasing? 
+ res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized + res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c) - x1 = max(math.floor(x), 0) - x2 = min(x1 + 1, samples.shape[3] - 1) - wx = x - math.floor(x) + #edge cases for same or polar opposites + res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] + res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1] + return res + + def generate_bilinear_data(length_old, length_new): + coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear") + ratios = coords_1 - coords_1.floor() + coords_1 = coords_1.to(torch.int64) + + coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1 + coords_2[:,:,:,-1] -= 1 + coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear") + coords_2 = coords_2.to(torch.int64) + return ratios, coords_1, coords_2 + + n,c,h,w = samples.shape + h_new, w_new = (height, width) + + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) - y1 = max(math.floor(y), 0) - y2 = min(y1 + 1, samples.shape[2] - 1) - wy = y - math.floor(y) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w)) - in1 = samples[:,:,y1,x1] - in2 = samples[:,:,y1,x2] - in3 = samples[:,:,y2,x1] - in4 = samples[:,:,y2,x2] + pass_1 = einops.rearrange(samples.gather(-2,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(samples.gather(-2,coords_2), 'n c h w -> (n h w) c') + ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') - if (x1 == x2) and (y1 == y2): - out_value = in1 - elif (x1 == x2): - out_value = algorithm(in1, in3, wy) - elif (y1 == y2): - out_value = algorithm(in1, in2, wx) - else: - o1 = algorithm(in1, in2, wx) - o2 = algorithm(in3, in4, wx) - out_value = algorithm(o1, o2, wy) + result = slerp(pass_1, pass_2, ratios) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w) - out1[:,:,y_dest,x_dest] = out_value - return out1 + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + + coords_1 = coords_1.expand((n, c, h_new, -1)) + coords_2 = coords_2.expand((n, c, h_new, -1)) + ratios = ratios.expand((n, 1, h_new, -1)) + + pass_1 = einops.rearrange(result.gather(-1,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(result.gather(-1,coords_2), 'n c h w -> (n h w) c') + ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + + result = slerp(pass_1, pass_2, ratios) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w_new) + return result def common_upscale(samples, width, height, upscale_method, crop): if crop == "center": From e1278fa925cf59350bae76dc3d0c59a0e9564789 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 13:30:59 -0400 Subject: [PATCH 12/48] Support old pytorch versions that don't have weights_only. 
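
Older torch.load signatures predate the weights_only keyword, so the loader
probes for support before passing it rather than catching a TypeError at load
time. A minimal sketch of an equivalent probe (the patch itself checks
torch.load.__code__.co_varnames; inspect.signature expresses the same idea,
and "model.ckpt" is only a placeholder path):

    import inspect
    import torch

    if "weights_only" in inspect.signature(torch.load).parameters:
        # Newer PyTorch: restrict unpickling to plain tensor data.
        sd = torch.load("model.ckpt", map_location="cpu", weights_only=True)
    else:
        # Older PyTorch: the keyword would raise TypeError, so fall back
        # to a regular (unsafe) load, as the warning below notes.
        sd = torch.load("model.ckpt", map_location="cpu")
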
--- comfy/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/utils.py b/comfy/utils.py index 300eda6aa..d58320b4a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -6,6 +6,10 @@ def load_torch_file(ckpt, safe_load=False): import safetensors.torch sd = safetensors.torch.load_file(ckpt, device="cpu") else: + if safe_load: + if not 'weights_only' in torch.load.__code__.co_varnames: + print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.") + safe_load = False if safe_load: pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True) else: From 87ab25fac77ff1d558fea3c02733a463cb1fa013 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 18:31:27 -0400 Subject: [PATCH 13/48] Do operations in same order as the one it replaces. --- comfy/utils.py | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 33c1c3dd7..f139fbb27 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -98,29 +98,27 @@ def bislerp(samples, width, height): n,c,h,w = samples.shape h_new, w_new = (height, width) - #linear h - ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + #linear w + ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + coords_1 = coords_1.expand((n, c, h, -1)) + coords_2 = coords_2.expand((n, c, h, -1)) + ratios = ratios.expand((n, 1, h, -1)) - coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w)) - coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w)) - ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w)) - - pass_1 = einops.rearrange(samples.gather(-2,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(samples.gather(-2,coords_2), 'n c h w -> (n h w) c') + pass_1 = einops.rearrange(samples.gather(-1,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(samples.gather(-1,coords_2), 'n c h w -> (n h w) c') ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w) + result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h, w=w_new) - #linear w - ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new) + #linear h + ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) + coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) + ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) - coords_1 = coords_1.expand((n, c, h_new, -1)) - coords_2 = coords_2.expand((n, c, h_new, -1)) - ratios = ratios.expand((n, 1, h_new, -1)) - - pass_1 = einops.rearrange(result.gather(-1,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(result.gather(-1,coords_2), 'n c h w -> (n h w) c') + pass_1 = einops.rearrange(result.gather(-2,coords_1), 'n c h w -> (n h w) c') + pass_2 = einops.rearrange(result.gather(-2,coords_2), 'n c h w -> (n h w) c') ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') result = slerp(pass_1, pass_2, ratios) From eb4bd7711acec9a2a2d4f1d4dcc1d32e1236c976 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 25 May 2023 18:42:56 -0400 Subject: [PATCH 14/48] Remove einops. 
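
einops was only needed for the two rearrange patterns used by the vectorized
bislerp, and both map directly onto movedim plus reshape, so the dependency
can be dropped. A minimal equivalence sketch (shapes chosen arbitrarily):

    import torch

    n, c, h, w = 2, 3, 4, 5
    t = torch.randn(n, c, h, w)

    # einops.rearrange(t, 'n c h w -> (n h w) c')
    flat = t.movedim(1, -1).reshape(-1, c)

    # einops.rearrange(flat, '(n h w) c -> n c h w', n=n, h=h, w=w)
    back = flat.reshape(n, h, w, c).movedim(-1, 1)

    assert torch.equal(t, back)
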
--- comfy/utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index f139fbb27..5ed9aaa02 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,6 +1,5 @@ import torch import math -import einops def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -104,12 +103,12 @@ def bislerp(samples, width, height): coords_2 = coords_2.expand((n, c, h, -1)) ratios = ratios.expand((n, 1, h, -1)) - pass_1 = einops.rearrange(samples.gather(-1,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(samples.gather(-1,coords_2), 'n c h w -> (n h w) c') - ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + pass_1 = samples.gather(-1,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = samples.gather(-1,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h, w=w_new) + result = result.reshape(n, h, w_new, c).movedim(-1, 1) #linear h ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new) @@ -117,12 +116,12 @@ def bislerp(samples, width, height): coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new)) ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new)) - pass_1 = einops.rearrange(result.gather(-2,coords_1), 'n c h w -> (n h w) c') - pass_2 = einops.rearrange(result.gather(-2,coords_2), 'n c h w -> (n h w) c') - ratios = einops.rearrange(ratios, 'n c h w -> (n h w) c') + pass_1 = result.gather(-2,coords_1).movedim(1, -1).reshape((-1,c)) + pass_2 = result.gather(-2,coords_2).movedim(1, -1).reshape((-1,c)) + ratios = ratios.movedim(1, -1).reshape((-1,1)) result = slerp(pass_1, pass_2, ratios) - result = einops.rearrange(result, '(n h w) c -> n c h w',n=n, h=h_new, w=w_new) + result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) return result def common_upscale(samples, width, height, upscale_method, crop): From 4d1ed829d9a934d9a303a725e325f90934854ac8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 26 May 2023 19:33:30 -0500 Subject: [PATCH 15/48] Don't load some model types if weight is zero --- nodes.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/nodes.py b/nodes.py index f0a93ebd5..68010f040 100644 --- a/nodes.py +++ b/nodes.py @@ -426,6 +426,9 @@ class LoraLoader: CATEGORY = "loaders" def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + lora_path = folder_paths.get_full_path("loras", lora_name) model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora_path, strength_model, strength_clip) return (model_lora, clip_lora) @@ -507,6 +510,9 @@ class ControlNetApply: CATEGORY = "conditioning" def apply_controlnet(self, conditioning, control_net, image, strength): + if strength == 0: + return (conditioning, ) + c = [] control_hint = image.movedim(-1,1) for t in conditioning: @@ -613,6 +619,9 @@ class unCLIPConditioning: CATEGORY = "conditioning" def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): + if strength == 0: + return (conditioning, ) + c = [] for t in conditioning: o = t[1].copy() From 679bd2845af8e22b2802cf326b99b40a26ba7811 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Fri, 26 May 2023 21:46:11 -0400 Subject: [PATCH 16/48] Safetensors isn't optional anymore. 
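
The try/except guard around the safetensors import is removed and the
.safetensors extension is registered unconditionally, making safetensors a
hard requirement. Code elsewhere can assume the import succeeds (illustrative
only; the path is a placeholder):

    import safetensors.torch

    sd = safetensors.torch.load_file("model.safetensors", device="cpu")
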
--- folder_paths.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index 28f117824..20b461c94 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -1,14 +1,7 @@ import os -supported_ckpt_extensions = set(['.ckpt', '.pth']) -supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth']) -try: - import safetensors.torch - supported_ckpt_extensions.add('.safetensors') - supported_pt_extensions.add('.safetensors') -except: - print("Could not import safetensors, safetensors support disabled.") - +supported_ckpt_extensions = set(['.ckpt', '.pth', '.safetensors']) +supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors']) folder_names_and_paths = {} From 73e85fb3f4b104053fb1ac5d0aea456e373ea8c8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:00:47 -0500 Subject: [PATCH 17/48] Improve error output for failed nodes --- execution.py | 237 ++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 204 insertions(+), 33 deletions(-) diff --git a/execution.py b/execution.py index 25f2fcacd..691beb102 100644 --- a/execution.py +++ b/execution.py @@ -297,24 +297,80 @@ def validate_inputs(prompt, item, validated): class_inputs = obj_class.INPUT_TYPES() required_inputs = class_inputs['required'] + + errors = [] + valid = True + for x in required_inputs: if x not in inputs: - return (False, "Required input is missing. {}, {}".format(class_type, x), unique_id) + error = { + "type": "required_input_missing", + "message": "Required input is missing", + "details": f"{x}", + "extra_info": { + "input_name": x + } + } + errors.append(error) + continue + val = inputs[x] info = required_inputs[x] type_input = info[0] if isinstance(val, list): if len(val) != 2: - return (False, "Bad Input. {}, {}".format(class_type, x), unique_id) + error = { + "type": "bad_linked_input", + "message": "Bad linked input, must be a length-2 list of [node_id, slot_index]", + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val + } + } + errors.append(error) + continue + o_id = val[0] o_class_type = prompt[o_id]['class_type'] r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES if r[val[1]] != type_input: - return (False, "Return type mismatch. 
{}, {}, {} != {}".format(class_type, x, r[val[1]], type_input), unique_id) - r = validate_inputs(prompt, o_id, validated) - if r[0] == False: - validated[o_id] = r - return r + received_type = r[val[1]] + details = f"{x}, {received_type} != {type_input}" + error = { + "type": "return_type_mismatch", + "message": "Return type mismatch between linked nodes", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_type": received_type + } + } + errors.append(error) + continue + try: + r = validate_inputs(prompt, o_id, validated) + if r[0] is False: + # `r` will be set in `validated[o_id]` already + valid = False + continue + except Exception as ex: + typ, _, tb = sys.exc_info() + valid = False + error_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "error_type": error_type, + "traceback": traceback.format_tb(tb) + } + }] + validated[o_id] = (False, reasons, o_id) + continue else: if type_input == "INT": val = int(val) @@ -328,26 +384,97 @@ def validate_inputs(prompt, item, validated): if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: - return (False, "Value {} smaller than min of {}. {}, {}".format(val, info[1]["min"], class_type, x), unique_id) + error = { + "type": "value_smaller_than_min", + "message": "Value {} smaller than min of {}".format(val, info[1]["min"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if "max" in info[1] and val > info[1]["max"]: - return (False, "Value {} bigger than max of {}. {}, {}".format(val, info[1]["max"], class_type, x), unique_id) + error = { + "type": "value_bigger_than_max", + "message": "Value {} bigger than max of {}".format(val, info[1]["max"]), + "details": f"{x}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue if hasattr(obj_class, "VALIDATE_INPUTS"): input_data_all = get_input_data(inputs, obj_class, unique_id) #ret = obj_class.VALIDATE_INPUTS(**input_data_all) ret = map_node_over_list(obj_class, input_data_all, "VALIDATE_INPUTS") - for r in ret: - if r != True: - return (False, "{}, {}".format(class_type, r), unique_id) + for i, r in enumerate(ret): + if r is not True: + details = f"{x}" + if r is not False: + details += f": {str(r)}" + else: + details += "." + + error = { + "type": "custom_validation_failed", + "message": "Custom validation failed for node", + "details": details, + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + } + } + errors.append(error) + continue else: if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. 
{}, {}: {} not in {}".format(class_type, x, val, type_input), unique_id) + input_config = info + list_info = "" + + # Don't send back gigantic lists like if they're lots of + # scanned model filepaths + if len(type_input) > 20: + list_info = f"(list of length {len(type_input)})" + input_config = None + else: + list_info = str(type_input) + + error = { + "type": "value_not_in_list", + "message": "Value not in list", + "details": f"{x}: '{val}' not in {list_info}", + "extra_info": { + "input_name": x, + "input_config": input_config, + "received_value": val, + } + } + errors.append(error) + continue + + if len(errors) > 0 or valid is not True: + ret = (False, errors, unique_id) + else: + ret = (True, [], unique_id) - ret = (True, "", unique_id) validated[unique_id] = ret return ret +def full_type_name(klass): + module = klass.__module__ + if module == 'builtins': + return klass.__qualname__ + return module + '.' + klass.__qualname__ + def validate_prompt(prompt): outputs = set() for x in prompt: @@ -356,7 +483,13 @@ def validate_prompt(prompt): outputs.add(x) if len(outputs) == 0: - return (False, "Prompt has no outputs", [], []) + error = { + "type": "prompt_no_outputs", + "message": "Prompt has no outputs", + "details": "", + "extra_info": {} + } + return (False, error, [], []) good_outputs = set() errors = [] @@ -364,34 +497,72 @@ def validate_prompt(prompt): validated = {} for o in outputs: valid = False - reason = "" + reasons = [] try: m = validate_inputs(prompt, o, validated) valid = m[0] - reason = m[1] - node_id = m[2] - except Exception as e: - print(traceback.format_exc()) + reasons = m[1] + except Exception as ex: + typ, _, tb = sys.exc_info() valid = False - reason = "Parsing error" - node_id = None + error_type = full_type_name(typ) + reasons = [{ + "type": "exception_during_validation", + "message": "Exception when validating node", + "details": str(ex), + "extra_info": { + "error_type": error_type, + "traceback": traceback.format_tb(tb) + } + }] + validated[o] = (False, reasons, o) - if valid == True: + if valid is True: good_outputs.add(o) else: - print("Failed to validate prompt for output {} {}".format(o, reason)) - print("output will be ignored") - errors += [(o, reason)] - if node_id is not None: - if node_id not in node_errors: - node_errors[node_id] = {"message": reason, "dependent_outputs": []} - node_errors[node_id]["dependent_outputs"].append(o) + print(f"Failed to validate prompt for output {o}:") + if len(reasons) > 0: + print("* (prompt):") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + errors += [(o, reasons)] + for node_id, result in validated.items(): + valid = result[0] + reasons = result[1] + # If a node upstream has errors, the nodes downstream will also + # be reported as invalid, but there will be no errors attached. + # So don't return those nodes as having errors in the response. 
+ if valid is not True and len(reasons) > 0: + if node_id not in node_errors: + class_type = prompt[node_id]['class_type'] + node_errors[node_id] = { + "errors": reasons, + "dependent_outputs": [], + "class_type": class_type + } + print(f"* {class_type} {node_id}:") + for reason in reasons: + print(f" - {reason['message']}: {reason['details']}") + node_errors[node_id]["dependent_outputs"].append(o) + print("Output will be ignored") if len(good_outputs) == 0: - errors_list = "\n".join(set(map(lambda a: "{}".format(a[1]), errors))) - return (False, "Prompt has no properly connected outputs\n {}".format(errors_list), list(good_outputs), node_errors) + errors_list = [] + for o, errors in errors: + for error in errors: + errors_list.append(f"{error['message']}: {error['details']}") + errors_list = "\n".join(errors_list) - return (True, "", list(good_outputs), node_errors) + error = { + "type": "prompt_no_good_outputs", + "message": "Prompt has no properly connected outputs", + "details": errors_list, + "extra_info": {} + } + + return (False, error, list(good_outputs), node_errors) + + return (True, None, list(good_outputs), node_errors) class PromptQueue: From cc4d3435d3590288e21f3adfd42f044a7e45fae4 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:48:55 -0500 Subject: [PATCH 18/48] Highlight failing nodes/inputs in frontend --- web/scripts/app.js | 74 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 67 insertions(+), 7 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 97b7c8d31..21fe94802 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -771,16 +771,25 @@ export class ComfyApp { LGraphCanvas.prototype.drawNodeShape = function (node, ctx, size, fgcolor, bgcolor, selected, mouse_over) { const res = origDrawNodeShape.apply(this, arguments); + const nodeErrors = self.lastPromptError?.node_errors[node.id]; + let color = null; + let lineWidth = 1; if (node.id === +self.runningNodeId) { color = "#0f0"; } else if (self.dragOverNode && node.id === self.dragOverNode.id) { color = "dodgerblue"; } + else if (self.lastPromptError != null && nodeErrors?.errors) { + color = "red"; + lineWidth = 2; + } + + self.graphTime = Date.now() if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; - ctx.lineWidth = 1; + ctx.lineWidth = lineWidth; ctx.globalAlpha = 0.8; ctx.beginPath(); if (shape == LiteGraph.BOX_SHAPE) @@ -807,11 +816,28 @@ export class ComfyApp { ctx.stroke(); ctx.strokeStyle = fgcolor; ctx.globalAlpha = 1; + } - if (self.progress) { - ctx.fillStyle = "green"; - ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); - ctx.fillStyle = bgcolor; + if (self.progress && node.id === +self.runningNodeId) { + ctx.fillStyle = "green"; + ctx.fillRect(0, 0, size[0] * (self.progress.value / self.progress.max), 6); + ctx.fillStyle = bgcolor; + } + + // Highlight inputs that failed validation + if (nodeErrors) { + ctx.lineWidth = 2; + ctx.strokeStyle = "red"; + for (const error of nodeErrors.errors) { + if (error.extra_info && error.extra_info.input_name) { + const inputIndex = node.findInputSlot(error.extra_info.input_name) + if (inputIndex !== -1) { + let pos = node.getConnectionPos(true, inputIndex); + ctx.beginPath(); + ctx.arc(pos[0] - node.pos[0], pos[1] - node.pos[1], 12, 0, 2 * Math.PI, false) + ctx.stroke(); + } + } } } @@ -1243,6 +1269,31 @@ export class ComfyApp { return { workflow, output }; } + #formatError(error) { + if (error == 
null) { + return "(unknown error)" + } + else if (typeof error === "string") { + return error; + } + else if (error.stack && error.message) { + return error.toString() + } + else if (error.response) { + let message = error.response.error.message; + if (error.response.error.details) + message += ": " + error.response.error.details; + for (const [nodeID, nodeError] of Object.entries(error.response.node_errors)) { + message += "\n" + nodeError.class_type + ":" + for (const errorReason of nodeError.errors) { + message += "\n - " + errorReason.message + ": " + errorReason.details + } + } + return message + } + return "(unknown error)" + } + async queuePrompt(number, batchCount = 1) { this.#queueItems.push({ number, batchCount }); @@ -1250,8 +1301,10 @@ export class ComfyApp { if (this.#processingQueue) { return; } - + this.#processingQueue = true; + this.lastPromptError = null; + try { while (this.#queueItems.length) { ({ number, batchCount } = this.#queueItems.pop()); @@ -1262,7 +1315,12 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - this.ui.dialog.show(error.response.error || error.toString()); + const formattedError = this.#formatError(error) + this.ui.dialog.show(formattedError); + if (error.response) { + this.lastPromptError = error.response; + this.canvas.draw(true, true); + } break; } @@ -1360,6 +1418,8 @@ export class ComfyApp { */ clean() { this.nodeOutputs = {}; + this.lastPromptError = null; + this.graphTime = null } } From c33b7c5549b7b277011e2c3f50215ba466afb205 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:54:13 -0500 Subject: [PATCH 19/48] Improve invalid prompt error message --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 691beb102..66753ff90 100644 --- a/execution.py +++ b/execution.py @@ -554,8 +554,8 @@ def validate_prompt(prompt): errors_list = "\n".join(errors_list) error = { - "type": "prompt_no_good_outputs", - "message": "Prompt has no properly connected outputs", + "type": "prompt_outputs_failed_validation", + "message": "Prompt outputs failed validation", "details": errors_list, "extra_info": {} } From 0d834e3a2ba6272b8cee6503f574c0f06002ddc3 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 11:59:30 -0500 Subject: [PATCH 20/48] Add missing input name/config --- execution.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/execution.py b/execution.py index 66753ff90..632aaa843 100644 --- a/execution.py +++ b/execution.py @@ -365,6 +365,8 @@ def validate_inputs(prompt, item, validated): "message": "Exception when validating node", "details": str(ex), "extra_info": { + "input_name": x, + "input_config": info, "error_type": error_type, "traceback": traceback.format_tb(tb) } From ffec815257ddf2371b880eafd575838210fcea07 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 12:48:06 -0500 Subject: [PATCH 21/48] Send back more information about exceptions that happen during execution --- execution.py | 173 ++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 123 insertions(+), 50 deletions(-) diff --git a/execution.py b/execution.py index 632aaa843..5ed9ff348 100644 --- a/execution.py +++ b/execution.py @@ -102,13 +102,19 @@ def get_output_data(obj, input_data_all): ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()} return output, ui +def 
format_value(x): + if isinstance(x, (int, float, bool, str)): + return x + else: + return str(x) + def recursive_execute(server, prompt, outputs, current_item, extra_data, executed, prompt_id, outputs_ui): unique_id = current_item inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] if unique_id in outputs: - return + return (True, None, None) for x in inputs: input_data = inputs[x] @@ -117,22 +123,64 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute input_unique_id = input_data[0] output_index = input_data[1] if input_unique_id not in outputs: - recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + result = recursive_execute(server, prompt, outputs, input_unique_id, extra_data, executed, prompt_id, outputs_ui) + if result[0] is not True: + # Another node failed further upstream + return result - input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) - if server.client_id is not None: - server.last_node_id = unique_id - server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) - obj = class_def() - - output_data, output_ui = get_output_data(obj, input_data_all) - outputs[unique_id] = output_data - if len(output_ui) > 0: - outputs_ui[unique_id] = output_ui + input_data_all = None + try: + input_data_all = get_input_data(inputs, class_def, unique_id, outputs, prompt, extra_data) if server.client_id is not None: - server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + server.last_node_id = unique_id + server.send_sync("executing", { "node": unique_id, "prompt_id": prompt_id }, server.client_id) + obj = class_def() + + output_data, output_ui = get_output_data(obj, input_data_all) + outputs[unique_id] = output_data + if len(output_ui) > 0: + outputs_ui[unique_id] = output_ui + if server.client_id is not None: + server.send_sync("executed", { "node": unique_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) + except comfy.model_management.InterruptProcessingException as iex: + print("Processing interrupted") + + # skip formatting inputs/outputs + error_details = { + "node_id": unique_id, + } + + return (False, error_details, iex) + except Exception as ex: + typ, _, tb = sys.exc_info() + exception_type = full_type_name(typ) + input_data_formatted = {} + if input_data_all is not None: + input_data_formatted = {} + for name, inputs in input_data_all.items(): + input_data_formatted[name] = [format_value(x) for x in inputs] + + output_data_formatted = {} + for node_id, node_outputs in outputs.items(): + output_data_formatted[node_id] = [[format_value(x) for x in l] for l in node_outputs] + + print("!!! 
Exception during processing !!!") + print(traceback.format_exc()) + + error_details = { + "node_id": unique_id, + "message": str(ex), + "exception_type": exception_type, + "traceback": traceback.format_tb(tb), + "current_inputs": input_data_formatted, + "current_outputs": output_data_formatted + } + return (False, error_details, ex) + executed.add(unique_id) + return (True, None, None) + def recursive_will_execute(prompt, outputs, current_item): unique_id = current_item inputs = prompt[unique_id]['inputs'] @@ -210,6 +258,44 @@ class PromptExecutor: self.old_prompt = {} self.server = server + def handle_execution_error(self, prompt_id, current_outputs, executed, error, ex): + # First, send back the status to the frontend depending + # on the exception type + if isinstance(ex, comfy.model_management.InterruptProcessingException): + mes = { + "prompt_id": prompt_id, + "executed": list(executed), + + "node_id": error["node_id"], + } + self.server.send_sync("execution_interrupted", mes, self.server.client_id) + else: + if self.server.client_id is not None: + mes = { + "prompt_id": prompt_id, + "executed": list(executed), + + "message": error["message"], + "exception_type": error["exception_type"], + "traceback": error["traceback"], + "node_id": error["node_id"], + "current_inputs": error["current_inputs"], + "current_outputs": error["current_outputs"], + } + self.server.send_sync("execution_error", mes, self.server.client_id) + + # Next, remove the subsequent outputs since they will not be executed + to_delete = [] + for o in self.outputs: + if (o not in current_outputs) and (o not in executed): + to_delete += [o] + if o in self.old_prompt: + d = self.old_prompt.pop(o) + del d + for o in to_delete: + d = self.outputs.pop(o) + del d + def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): nodes.interrupt_processing(False) @@ -244,42 +330,29 @@ class PromptExecutor: if self.server.client_id is not None: self.server.send_sync("execution_cached", { "nodes": list(current_outputs) , "prompt_id": prompt_id}, self.server.client_id) executed = set() - try: - to_execute = [] - for x in list(execute_outputs): - to_execute += [(0, x)] + output_node_id = None + to_execute = [] - while len(to_execute) > 0: - #always execute the output that depends on the least amount of unexecuted nodes first - to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) - x = to_execute.pop(0)[-1] + for node_id in list(execute_outputs): + to_execute += [(0, node_id)] - recursive_execute(self.server, prompt, self.outputs, x, extra_data, executed, prompt_id, self.outputs_ui) - except Exception as e: - if isinstance(e, comfy.model_management.InterruptProcessingException): - print("Processing interrupted") - else: - message = str(traceback.format_exc()) - print(message) - if self.server.client_id is not None: - self.server.send_sync("execution_error", { "message": message, "prompt_id": prompt_id }, self.server.client_id) + while len(to_execute) > 0: + #always execute the output that depends on the least amount of unexecuted nodes first + to_execute = sorted(list(map(lambda a: (len(recursive_will_execute(prompt, self.outputs, a[-1])), a[-1]), to_execute))) + output_node_id = to_execute.pop(0)[-1] - to_delete = [] - for o in self.outputs: - if (o not in current_outputs) and (o not in executed): - to_delete += [o] - if o in self.old_prompt: - d = self.old_prompt.pop(o) - del d - for o in to_delete: - d = self.outputs.pop(o) - del d - finally: - for x in 
executed: - self.old_prompt[x] = copy.deepcopy(prompt[x]) - self.server.last_node_id = None - if self.server.client_id is not None: - self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) + # This call shouldn't raise anything if there's an error deep in + # the actual SD code, instead it will report the node where the + # error was raised + success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) + if success is not True: + self.handle_execution_error(prompt_id, current_outputs, executed, error, ex) + + for x in executed: + self.old_prompt[x] = copy.deepcopy(prompt[x]) + self.server.last_node_id = None + if self.server.client_id is not None: + self.server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, self.server.client_id) print("Prompt executed in {:.2f} seconds".format(time.perf_counter() - execution_start_time)) gc.collect() @@ -359,7 +432,7 @@ def validate_inputs(prompt, item, validated): except Exception as ex: typ, _, tb = sys.exc_info() valid = False - error_type = full_type_name(typ) + exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", @@ -367,7 +440,7 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, - "error_type": error_type, + "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }] @@ -507,13 +580,13 @@ def validate_prompt(prompt): except Exception as ex: typ, _, tb = sys.exc_info() valid = False - error_type = full_type_name(typ) + exception_type = full_type_name(typ) reasons = [{ "type": "exception_during_validation", "message": "Exception when validating node", "details": str(ex), "extra_info": { - "error_type": error_type, + "exception_type": exception_type, "traceback": traceback.format_tb(tb) } }] From 6b2a8a3845972bcff02184aaa8ded6eace8300ad Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:03:41 -0500 Subject: [PATCH 22/48] Show message in the frontend if prompt execution raises an exception --- execution.py | 14 +++++++++----- web/scripts/api.js | 6 ++++++ web/scripts/app.js | 35 ++++++++++++++++++++++++++++++----- 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/execution.py b/execution.py index 5ed9ff348..f79c3d351 100644 --- a/execution.py +++ b/execution.py @@ -258,27 +258,31 @@ class PromptExecutor: self.old_prompt = {} self.server = server - def handle_execution_error(self, prompt_id, current_outputs, executed, error, ex): + def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex): + node_id = error["node_id"] + class_type = prompt[node_id]["class_type"] + # First, send back the status to the frontend depending # on the exception type if isinstance(ex, comfy.model_management.InterruptProcessingException): mes = { "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, "executed": list(executed), - - "node_id": error["node_id"], } self.server.send_sync("execution_interrupted", mes, self.server.client_id) else: if self.server.client_id is not None: mes = { "prompt_id": prompt_id, + "node_id": node_id, + "node_type": class_type, "executed": list(executed), "message": error["message"], "exception_type": error["exception_type"], "traceback": error["traceback"], - "node_id": error["node_id"], "current_inputs": error["current_inputs"], 
"current_outputs": error["current_outputs"], } @@ -346,7 +350,7 @@ class PromptExecutor: # error was raised success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) if success is not True: - self.handle_execution_error(prompt_id, current_outputs, executed, error, ex) + self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) diff --git a/web/scripts/api.js b/web/scripts/api.js index 4f061c358..378165b3a 100644 --- a/web/scripts/api.js +++ b/web/scripts/api.js @@ -88,6 +88,12 @@ class ComfyApi extends EventTarget { case "executed": this.dispatchEvent(new CustomEvent("executed", { detail: msg.data })); break; + case "execution_start": + this.dispatchEvent(new CustomEvent("execution_start", { detail: msg.data })); + break; + case "execution_error": + this.dispatchEvent(new CustomEvent("execution_error", { detail: msg.data })); + break; default: if (this.#registered.has(msg.type)) { this.dispatchEvent(new CustomEvent(msg.type, { detail: msg.data })); diff --git a/web/scripts/app.js b/web/scripts/app.js index 21fe94802..e8ab32cf9 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -784,8 +784,10 @@ export class ComfyApp { color = "red"; lineWidth = 2; } - - self.graphTime = Date.now() + else if (self.lastExecutionError && +self.lastExecutionError.node_id === node.id) { + color = "#f0f"; + lineWidth = 2; + } if (color) { const shape = node._shape || node.constructor.shape || LiteGraph.ROUND_SHAPE; @@ -895,6 +897,17 @@ export class ComfyApp { } }); + api.addEventListener("execution_start", ({ detail }) => { + this.lastExecutionError = null + }); + + api.addEventListener("execution_error", ({ detail }) => { + this.lastExecutionError = detail; + const formattedError = this.#formatExecutionError(detail); + this.ui.dialog.show(formattedError); + this.canvas.draw(true, true); + }); + api.init(); } @@ -1269,7 +1282,7 @@ export class ComfyApp { return { workflow, output }; } - #formatError(error) { + #formatPromptError(error) { if (error == null) { return "(unknown error)" } @@ -1294,6 +1307,18 @@ export class ComfyApp { return "(unknown error)" } + #formatExecutionError(error) { + if (error == null) { + return "(unknown error)" + } + + const traceback = error.traceback.join("") + const nodeId = error.node_id + const nodeType = error.node_type + + return `Error occurred when executing ${nodeType}:\n\n${error.message}\n\n${traceback}` + } + async queuePrompt(number, batchCount = 1) { this.#queueItems.push({ number, batchCount }); @@ -1315,7 +1340,7 @@ export class ComfyApp { try { await api.queuePrompt(number, p); } catch (error) { - const formattedError = this.#formatError(error) + const formattedError = this.#formatPromptError(error) this.ui.dialog.show(formattedError); if (error.response) { this.lastPromptError = error.response; @@ -1419,7 +1444,7 @@ export class ComfyApp { clean() { this.nodeOutputs = {}; this.lastPromptError = null; - this.graphTime = null + this.lastExecutionError = null; } } From e2d080b6941783e50155f694c11ab0da1b1ae240 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:07:51 -0500 Subject: [PATCH 23/48] Return null for value format --- execution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index f79c3d351..9cebce928 100644 --- a/execution.py +++ b/execution.py @@ -103,7 +103,9 @@ 
def get_output_data(obj, input_data_all): return output, ui def format_value(x): - if isinstance(x, (int, float, bool, str)): + if x is None: + return None + elif isinstance(x, (int, float, bool, str)): return x else: return str(x) From a9e7e237248296c8fe0d79991e0f8c2c0f2cf530 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 25 May 2023 13:11:34 -0500 Subject: [PATCH 24/48] Fix --- execution.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/execution.py b/execution.py index 9cebce928..ffea00a8c 100644 --- a/execution.py +++ b/execution.py @@ -499,9 +499,7 @@ def validate_inputs(prompt, item, validated): if r is not True: details = f"{x}" if r is not False: - details += f": {str(r)}" - else: - details += "." + details += f" - {str(r)}" error = { "type": "custom_validation_failed", From 62bdd9d26aba086ffbeedd118140e2806e6f4345 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Fri, 26 May 2023 16:35:54 -0500 Subject: [PATCH 25/48] Catch typecast errors --- execution.py | 43 ++++++++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/execution.py b/execution.py index ffea00a8c..6af58a673 100644 --- a/execution.py +++ b/execution.py @@ -424,7 +424,8 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, - "received_type": received_type + "received_type": received_type, + "linked_node": val } } errors.append(error) @@ -440,28 +441,44 @@ def validate_inputs(prompt, item, validated): valid = False exception_type = full_type_name(typ) reasons = [{ - "type": "exception_during_validation", - "message": "Exception when validating node", + "type": "exception_during_inner_validation", + "message": "Exception when validating inner node", "details": str(ex), "extra_info": { "input_name": x, "input_config": info, "exception_type": exception_type, - "traceback": traceback.format_tb(tb) + "traceback": traceback.format_tb(tb), + "linked_node": val } }] validated[o_id] = (False, reasons, o_id) continue else: - if type_input == "INT": - val = int(val) - inputs[x] = val - if type_input == "FLOAT": - val = float(val) - inputs[x] = val - if type_input == "STRING": - val = str(val) - inputs[x] = val + try: + if type_input == "INT": + val = int(val) + inputs[x] = val + if type_input == "FLOAT": + val = float(val) + inputs[x] = val + if type_input == "STRING": + val = str(val) + inputs[x] = val + except Exception as ex: + error = { + "type": "invalid_input_type", + "message": f"Failed to convert an input value to a {type_input} value", + "details": f"{x}, {val}, {ex}", + "extra_info": { + "input_name": x, + "input_config": info, + "received_value": val, + "exception_message": str(ex) + } + } + errors.append(error) + continue if len(info) > 1: if "min" in info[1] and val < info[1]["min"]: From 52c9590b7b65dba86e8622f6ad38974bc4045f31 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 01:51:39 -0500 Subject: [PATCH 26/48] Exception message --- execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/execution.py b/execution.py index 6af58a673..52c264b0f 100644 --- a/execution.py +++ b/execution.py @@ -447,6 +447,7 @@ def validate_inputs(prompt, item, validated): "extra_info": { "input_name": x, "input_config": info, + "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "linked_node": val From 
03f2d0a764726641e848ba4e069c8809a502afdf Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 02:02:11 -0500 Subject: [PATCH 27/48] Rename exception message field --- execution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/execution.py b/execution.py index 52c264b0f..1a9a1ff73 100644 --- a/execution.py +++ b/execution.py @@ -171,7 +171,7 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute error_details = { "node_id": unique_id, - "message": str(ex), + "exception_message": str(ex), "exception_type": exception_type, "traceback": traceback.format_tb(tb), "current_inputs": input_data_formatted, @@ -282,7 +282,7 @@ class PromptExecutor: "node_type": class_type, "executed": list(executed), - "message": error["message"], + "exception_message": error["exception_message"], "exception_type": error["exception_type"], "traceback": error["traceback"], "current_inputs": error["current_inputs"], From 00646b0813e4f395725f3013f18b13a46f4d619d Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 27 May 2023 21:48:49 -0500 Subject: [PATCH 28/48] Bitwise operations for masks --- comfy_extras/nodes_mask.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 9916f3b21..9134c24da 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -167,7 +167,7 @@ class MaskComposite: "source": ("MASK",), "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract"],), + "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), } } @@ -193,6 +193,12 @@ class MaskComposite: output[top:bottom, left:right] = destination_portion + source_portion elif operation == "subtract": output[top:bottom, left:right] = destination_portion - source_portion + elif operation == "and": + output[top:bottom, left:right] = torch.bitwise_and(destination_portion.bool(), source_portion.bool()).float() + elif operation == "or": + output[top:bottom, left:right] = torch.bitwise_or(destination_portion.bool(), source_portion.bool()).float() + elif operation == "xor": + output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.bool(), source_portion.bool()).float() output = torch.clamp(output, 0.0, 1.0) From ad81fd682a5e5e7c1f258d7c11a000c0dfd07be3 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 00:32:26 -0400 Subject: [PATCH 29/48] Fix issue with cancelling prompt. --- execution.py | 1 + 1 file changed, 1 insertion(+) diff --git a/execution.py b/execution.py index 1a9a1ff73..218a84c36 100644 --- a/execution.py +++ b/execution.py @@ -353,6 +353,7 @@ class PromptExecutor: success, error, ex = recursive_execute(self.server, prompt, self.outputs, output_node_id, extra_data, executed, prompt_id, self.outputs_ui) if success is not True: self.handle_execution_error(prompt_id, prompt, current_outputs, executed, error, ex) + break for x in executed: self.old_prompt[x] = copy.deepcopy(prompt[x]) From f3ac938b4a5c031adb9ee2951f26360d6a2b36de Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 00:42:53 -0400 Subject: [PATCH 30/48] Round the mask values for bitwise operations. 
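
Casting a float mask straight to bool turns every nonzero value into True, so
the soft edge of a blurred mask would count as fully "on". Rounding first
thresholds the mask near 0.5 before the bitwise op. A small illustration of
the difference:

    import torch

    m = torch.tensor([0.0, 0.2, 0.4, 0.6, 1.0])
    m.bool()          # tensor([False,  True,  True,  True,  True])
    m.round().bool()  # tensor([False, False, False,  True,  True])
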
--- comfy_extras/nodes_mask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 9134c24da..15377af14 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -194,11 +194,11 @@ class MaskComposite: elif operation == "subtract": output[top:bottom, left:right] = destination_portion - source_portion elif operation == "and": - output[top:bottom, left:right] = torch.bitwise_and(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float() elif operation == "or": - output[top:bottom, left:right] = torch.bitwise_or(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float() elif operation == "xor": - output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.bool(), source_portion.bool()).float() + output[top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float() output = torch.clamp(output, 0.0, 1.0) From 0fc483dcfdef457b50d3a67e66b4f463e6ef9d62 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 01:52:09 -0400 Subject: [PATCH 31/48] Refactor diffusers model convert code to be able to reuse it. --- comfy/diffusers_convert.py | 107 ----------------------------------- comfy/diffusers_load.py | 111 +++++++++++++++++++++++++++++++++++++ nodes.py | 4 +- 3 files changed, 113 insertions(+), 109 deletions(-) create mode 100644 comfy/diffusers_load.py diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py index ceca80305..1eab54d4b 100644 --- a/comfy/diffusers_convert.py +++ b/comfy/diffusers_convert.py @@ -1,14 +1,5 @@ -import json -import os -import yaml - -import folder_paths -from comfy.ldm.util import instantiate_from_config -from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE -import os.path as osp import re import torch -from safetensors.torch import load_file, save_file # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -262,101 +253,3 @@ def convert_text_enc_state_dict(text_enc_dict): return text_enc_dict -def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): - diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) - diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) - - # magic - v2 = diffusers_unet_conf["sample_size"] == 96 - if 'prediction_type' in diffusers_scheduler_conf: - v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' - - if v2: - if v_pred: - config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') - else: - config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') - - with open(config_path, 'r') as stream: - config = yaml.safe_load(stream) - - model_config_params = config['model']['params'] - clip_config = model_config_params['cond_stage_config'] - scale_factor = model_config_params['scale_factor'] - vae_config = model_config_params['first_stage_config'] - vae_config['scale_factor'] = scale_factor - model_config_params["unet_config"]["params"]["use_fp16"] = fp16 - - unet_path = 
osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") - text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") - - # Load models from safetensors if it exists, if it doesn't pytorch - if osp.exists(unet_path): - unet_state_dict = load_file(unet_path, device="cpu") - else: - unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") - unet_state_dict = torch.load(unet_path, map_location="cpu") - - if osp.exists(vae_path): - vae_state_dict = load_file(vae_path, device="cpu") - else: - vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") - vae_state_dict = torch.load(vae_path, map_location="cpu") - - if osp.exists(text_enc_path): - text_enc_dict = load_file(text_enc_path, device="cpu") - else: - text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") - text_enc_dict = torch.load(text_enc_path, map_location="cpu") - - # Convert the UNet model - unet_state_dict = convert_unet_state_dict(unet_state_dict) - unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} - - # Convert the VAE model - vae_state_dict = convert_vae_state_dict(vae_state_dict) - vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} - - # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper - is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict - - if is_v20_model: - # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm - text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} - text_enc_dict = convert_text_enc_state_dict_v20(text_enc_dict) - text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} - else: - text_enc_dict = convert_text_enc_state_dict(text_enc_dict) - text_enc_dict = {"cond_stage_model.transformer." 
+ k: v for k, v in text_enc_dict.items()} - - # Put together new checkpoint - sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} - - clip = None - vae = None - - class WeightsLoader(torch.nn.Module): - pass - - w = WeightsLoader() - load_state_dict_to = [] - if output_vae: - vae = VAE(scale_factor=scale_factor, config=vae_config) - w.first_stage_model = vae.first_stage_model - load_state_dict_to = [w] - - if output_clip: - clip = CLIP(config=clip_config, embedding_directory=embedding_directory) - w.cond_stage_model = clip.cond_stage_model - load_state_dict_to = [w] - - model = instantiate_from_config(config["model"]) - model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) - - if fp16: - model = model.half() - - return ModelPatcher(model), clip, vae diff --git a/comfy/diffusers_load.py b/comfy/diffusers_load.py new file mode 100644 index 000000000..43877fb83 --- /dev/null +++ b/comfy/diffusers_load.py @@ -0,0 +1,111 @@ +import json +import os +import yaml + +import folder_paths +from comfy.ldm.util import instantiate_from_config +from comfy.sd import ModelPatcher, load_model_weights, CLIP, VAE +import os.path as osp +import re +import torch +from safetensors.torch import load_file, save_file +import diffusers_convert + +def load_diffusers(model_path, fp16=True, output_vae=True, output_clip=True, embedding_directory=None): + diffusers_unet_conf = json.load(open(osp.join(model_path, "unet/config.json"))) + diffusers_scheduler_conf = json.load(open(osp.join(model_path, "scheduler/scheduler_config.json"))) + + # magic + v2 = diffusers_unet_conf["sample_size"] == 96 + if 'prediction_type' in diffusers_scheduler_conf: + v_pred = diffusers_scheduler_conf['prediction_type'] == 'v_prediction' + + if v2: + if v_pred: + config_path = folder_paths.get_full_path("configs", 'v2-inference-v.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v2-inference.yaml') + else: + config_path = folder_paths.get_full_path("configs", 'v1-inference.yaml') + + with open(config_path, 'r') as stream: + config = yaml.safe_load(stream) + + model_config_params = config['model']['params'] + clip_config = model_config_params['cond_stage_config'] + scale_factor = model_config_params['scale_factor'] + vae_config = model_config_params['first_stage_config'] + vae_config['scale_factor'] = scale_factor + model_config_params["unet_config"]["params"]["use_fp16"] = fp16 + + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") + + # Convert the UNet model + unet_state_dict = diffusers_convert.convert_unet_state_dict(unet_state_dict) + 
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = diffusers_convert.convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} + + # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper + is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict + + if is_v20_model: + # Need to add the tag 'transformer' in advance so we can knock it out from the final layer-norm + text_enc_dict = {"transformer." + k: v for k, v in text_enc_dict.items()} + text_enc_dict = diffusers_convert.convert_text_enc_state_dict_v20(text_enc_dict) + text_enc_dict = {"cond_stage_model.model." + k: v for k, v in text_enc_dict.items()} + else: + text_enc_dict = diffusers_convert.convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} + + # Put together new checkpoint + sd = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + + clip = None + vae = None + + class WeightsLoader(torch.nn.Module): + pass + + w = WeightsLoader() + load_state_dict_to = [] + if output_vae: + vae = VAE(scale_factor=scale_factor, config=vae_config) + w.first_stage_model = vae.first_stage_model + load_state_dict_to = [w] + + if output_clip: + clip = CLIP(config=clip_config, embedding_directory=embedding_directory) + w.cond_stage_model = clip.cond_stage_model + load_state_dict_to = [w] + + model = instantiate_from_config(config["model"]) + model = load_model_weights(model, sd, verbose=False, load_state_dict_to=load_state_dict_to) + + if fp16: + model = model.half() + + return ModelPatcher(model), clip, vae diff --git a/nodes.py b/nodes.py index 68010f040..90444a92c 100644 --- a/nodes.py +++ b/nodes.py @@ -17,7 +17,7 @@ import safetensors.torch sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) -import comfy.diffusers_convert +import comfy.diffusers_load import comfy.samplers import comfy.sample import comfy.sd @@ -377,7 +377,7 @@ class DiffusersLoader: model_path = path break - return comfy.diffusers_convert.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) + return comfy.diffusers_load.load_diffusers(model_path, fp16=comfy.model_management.should_use_fp16(), output_vae=output_vae, output_clip=output_clip, embedding_directory=folder_paths.get_folder_paths("embeddings")) class unCLIPCheckpointLoader: From a532888846809de7b8890e8beb10ea87edf39d7e Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Sun, 28 May 2023 02:02:09 -0400 Subject: [PATCH 32/48] Support VAEs in diffusers format. --- comfy/sd.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index c6be900ad..4df149fe1 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ from .t2i_adapter import adapter from . import utils from . import clip_vision from . import gligen +from . 
import diffusers_convert def load_model_weights(model, sd, verbose=False, load_state_dict_to=[]): m, u = model.load_state_dict(sd, strict=False) @@ -504,10 +505,16 @@ class VAE: if config is None: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} - self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss", ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(ddconfig, {'target': 'torch.nn.Identity'}, 4, monitor="val/rec_loss") else: - self.first_stage_model = AutoencoderKL(**(config['params']), ckpt_path=ckpt_path) + self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() + if ckpt_path is not None: + sd = utils.load_torch_file(ckpt_path) + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) + self.first_stage_model.load_state_dict(sd, strict=False) + self.scale_factor = scale_factor if device is None: device = model_management.get_torch_device() From 23ffafeb5d4a25bb5e41c34c9f04a0733643892c Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" Date: Sun, 28 May 2023 23:31:40 +0900 Subject: [PATCH 33/48] typo fix: field name in error message --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index e8ab32cf9..26670239b 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1316,7 +1316,7 @@ export class ComfyApp { const nodeId = error.node_id const nodeType = error.node_type - return `Error occurred when executing ${nodeType}:\n\n${error.message}\n\n${traceback}` + return `Error occurred when executing ${nodeType}:\n\n${error.exception_message}\n\n${traceback}` } async queuePrompt(number, batchCount = 1) { From b9818eb910b6ce683c38602c9b8fbd3979d97aaf Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 02:48:50 -0400 Subject: [PATCH 34/48] Add route to get safetensors metadata: /view_metadata/loras?filename=lora.safetensors --- comfy/utils.py | 9 +++++++++ folder_paths.py | 2 ++ server.py | 25 ++++++++++++++++++++++++- 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/comfy/utils.py b/comfy/utils.py index 5ed9aaa02..4e84e870b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1,5 +1,6 @@ import torch import math +import struct def load_torch_file(ckpt, safe_load=False): if ckpt.lower().endswith(".safetensors"): @@ -50,6 +51,14 @@ def transformers_convert(sd, prefix_from, prefix_to, number): sd[k_to] = weights[shape_from*x:shape_from*(x + 1)] return sd +def safetensors_header(safetensors_path, max_size=100*1024*1024): + with open(safetensors_path, "rb") as f: + header = f.read(8) + length_of_header = struct.unpack('<Q', header)[0] + if length_of_header > max_size: + return None + return f.read(length_of_header) + def bislerp(samples, width, height): def slerp(b1, b2, r): '''slerps batches b1, b2 according to ratio r, batches should be flat e.g.
NxC''' diff --git a/folder_paths.py b/folder_paths.py index 20b461c94..19245a617 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -126,11 +126,13 @@ def filter_files_extensions(files, extensions): def get_full_path(folder_name, filename): global folder_names_and_paths folders = folder_names_and_paths[folder_name] + filename = os.path.relpath(os.path.join("/", filename), "/") for x in folders[0]: full_path = os.path.join(x, filename) if os.path.isfile(full_path): return full_path + return None def get_filename_list(folder_name): global folder_names_and_paths diff --git a/server.py b/server.py index c0f79cbd5..72c565a63 100644 --- a/server.py +++ b/server.py @@ -22,7 +22,7 @@ except ImportError: import mimetypes from comfy.cli_args import args - +import comfy.utils @web.middleware async def cache_control(request: web.Request, handler): @@ -257,6 +257,29 @@ class PromptServer(): return web.Response(status=404) + @routes.get("/view_metadata/{folder_name}") + async def view_metadata(request): + folder_name = request.match_info.get("folder_name", None) + if folder_name is None: + return web.Response(status=404) + if not "filename" in request.rel_url.query: + return web.Response(status=404) + + filename = request.rel_url.query["filename"] + if not filename.endswith(".safetensors"): + return web.Response(status=404) + + safetensors_path = folder_paths.get_full_path(folder_name, filename) + if safetensors_path is None: + return web.Response(status=404) + out = comfy.utils.safetensors_header(safetensors_path, max_size=1024*1024) + if out is None: + return web.Response(status=404) + dt = json.loads(out) + if not "__metadata__" in dt: + return web.Response(status=404) + return web.json_response(dt["__metadata__"]) + @routes.get("/prompt") async def get_prompt(request): return web.json_response(self.get_queue_info()) From 560e9f7a43242c51da2589a33f659ecd41914b20 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 11:29:00 -0400 Subject: [PATCH 35/48] Disable repo owner validation in update.py --- .ci/update_windows/update.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index c09f29a80..ef9374c44 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -41,7 +41,7 @@ def pull(repo, remote_name='origin', branch='master'): else: raise AssertionError('Unknown merge analysis result') - +pygit2.option(pygit2.GIT_OPT_SET_OWNER_VALIDATION, 0) repo = pygit2.Repository(str(sys.argv[1])) ident = pygit2.Signature('comfyui', 'comfy@ui') try: From 08abd838b82ea8d08a7e6f1484140d1694180381 Mon Sep 17 00:00:00 2001 From: "Lt.Dr.Data" Date: Tue, 30 May 2023 15:26:45 +0900 Subject: [PATCH 36/48] HOTFIX: Patched the conflict issue between the Combo Refresh feature and PrimitiveNodes. --- web/scripts/app.js | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/web/scripts/app.js b/web/scripts/app.js index 26670239b..64adc3e6a 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1424,6 +1424,11 @@ export class ComfyApp { const def = defs[node.type]; + // HOTFIX: The current patch is designed to prevent the rest of the code from breaking due to primitive nodes, + // and additional work is needed to consider the primitive logic in the refresh logic. 
+ if(!def) + continue; + for(const widgetNum in node.widgets) { const widget = node.widgets[widgetNum] if(widget.type == "combo" && def["input"]["required"][widget.name] !== undefined) { From eb448dd8e18125b569bea9002f909769678a6c43 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 May 2023 12:36:41 -0400 Subject: [PATCH 37/48] Auto load model in lowvram if not enough memory. --- comfy/model_management.py | 46 ++++++++++++++++++++++++--------------- comfy/sd.py | 18 +++++++++++++-- 2 files changed, 45 insertions(+), 19 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index c15323219..10a706793 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -15,9 +15,8 @@ vram_state = VRAMState.NORMAL_VRAM set_vram_to = VRAMState.NORMAL_VRAM total_vram = 0 -total_vram_available_mb = -1 -accelerate_enabled = False +lowvram_available = True xpu_available = False directml_enabled = False @@ -31,11 +30,12 @@ if args.directml is not None: directml_device = torch_directml.device(device_index) print("Using directml with device:", torch_directml.device_name(device_index)) # torch_directml.disable_tiled_resources(True) + lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. try: import torch if directml_enabled: - total_vram = 4097 #TODO + pass #TODO else: try: import intel_extension_for_pytorch as ipex @@ -46,7 +46,7 @@ try: total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024) total_ram = psutil.virtual_memory().total / (1024 * 1024) if not args.normalvram and not args.cpu: - if total_vram <= 4096: + if lowvram_available and total_vram <= 4096: print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. 
If you don't want this use: --normalvram") set_vram_to = VRAMState.LOW_VRAM elif total_vram > total_ram * 1.1 and total_vram > 14336: @@ -92,6 +92,7 @@ if ENABLE_PYTORCH_ATTENTION: if args.lowvram: set_vram_to = VRAMState.LOW_VRAM + lowvram_available = True elif args.novram: set_vram_to = VRAMState.NO_VRAM elif args.highvram: @@ -103,18 +104,18 @@ if args.force_fp32: FORCE_FP32 = True -if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + +if lowvram_available: try: import accelerate - accelerate_enabled = True - vram_state = set_vram_to + if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM): + vram_state = set_vram_to except Exception as e: import traceback print(traceback.format_exc()) - print("ERROR: COULD NOT ENABLE LOW VRAM MODE.") + print("ERROR: LOW VRAM MODE NEEDS accelerate.") + lowvram_available = False - total_vram_available_mb = (total_vram - 1024) // 2 - total_vram_available_mb = int(max(256, total_vram_available_mb)) try: if torch.backends.mps.is_available(): @@ -199,22 +200,33 @@ def load_model_gpu(model): model.unpatch_model() raise e - model.model_patches_to(get_torch_device()) + torch_dev = get_torch_device() + model.model_patches_to(torch_dev) + + vram_set_state = vram_state + if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): + model_size = model.model_size() + current_free_mem = get_free_memory(torch_dev) + lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.2 )) + if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary + vram_set_state = VRAMState.LOW_VRAM + current_loaded_model = model - if vram_state == VRAMState.CPU: + + if vram_set_state == VRAMState.CPU: pass - elif vram_state == VRAMState.MPS: + elif vram_set_state == VRAMState.MPS: mps_device = torch.device("mps") real_model.to(mps_device) pass - elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM: + elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM: model_accelerated = False real_model.to(get_torch_device()) else: - if vram_state == VRAMState.NO_VRAM: + if vram_set_state == VRAMState.NO_VRAM: device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - elif vram_state == VRAMState.LOW_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"}) + elif vram_set_state == VRAMState.LOW_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device()) model_accelerated = True diff --git a/comfy/sd.py b/comfy/sd.py index 4df149fe1..ce17994f7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -286,15 +286,29 @@ def model_lora_keys(model, key_map={}): return key_map + class ModelPatcher: - def __init__(self, model): + def __init__(self, model, size=0): + self.size = size self.model = model self.patches = [] self.backup = {} self.model_options = {"transformer_options":{}} + self.model_size() + + def model_size(self): + if self.size > 0: + return self.size + model_sd = self.model.state_dict() + size = 0 + for k in model_sd: + t = model_sd[k] + size += t.nelement() * t.element_size() + self.size = size + return size def clone(self): - n = ModelPatcher(self.model) + n = ModelPatcher(self.model, self.size) 
n.patches = self.patches[:] n.model_options = copy.deepcopy(self.model_options) return n From 2260802d90c41f1475a7bf2960aa018dc25f1001 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 30 May 2023 16:44:09 -0400 Subject: [PATCH 38/48] Check if folder_name is valid instead of just throwing exception. --- folder_paths.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/folder_paths.py b/folder_paths.py index 19245a617..fc37e52c7 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -125,6 +125,8 @@ def filter_files_extensions(files, extensions): def get_full_path(folder_name, filename): global folder_names_and_paths + if folder_name not in folder_names_and_paths: + return None folders = folder_names_and_paths[folder_name] filename = os.path.relpath(os.path.join("/", filename), "/") for x in folders[0]: From 04f4fba013da1f556fc310235d5a30c2bfe682e8 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 16:01:49 -0500 Subject: [PATCH 39/48] Fix litegraph dialog CSS --- web/index.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/index.html b/web/index.html index bb79433ce..da0adb6c2 100644 --- a/web/index.html +++ b/web/index.html @@ -14,5 +14,5 @@ window.graph = app.graph; - + From 468c27afea29928d7d9fcd208e1137a36118ad13 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Tue, 30 May 2023 16:06:17 -0500 Subject: [PATCH 40/48] Fix litegraph dialog z-index/font --- web/style.css | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/web/style.css b/web/style.css index 87f096e14..db82887c3 100644 --- a/web/style.css +++ b/web/style.css @@ -289,6 +289,11 @@ button.comfy-queue-btn { /* Context menu */ +.litegraph .dialog { + z-index: 1; + font-family: Arial; +} + .litegraph .litemenu-entry.has_submenu { position: relative; padding-right: 20px; From 8ef197f02852b65509d6ebe06df8794b96a07f2f Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 29 May 2023 11:26:57 -0400 Subject: [PATCH 41/48] Keep list of filenames and only refresh it when something changes. --- folder_paths.py | 47 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/folder_paths.py b/folder_paths.py index fc37e52c7..f3d1b8773 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -31,6 +31,8 @@ output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "ou temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") +filename_list_cache = {} + if not os.path.exists(input_directory): os.makedirs(input_directory) @@ -111,12 +113,18 @@ def get_folder_paths(folder_name): return folder_names_and_paths[folder_name][0][:] def recursive_search(directory): + if not os.path.isdir(directory): + return [], {} result = [] + dirs = {directory: os.path.getmtime(directory)} for root, subdir, file in os.walk(directory, followlinks=True): for filepath in file: #we os.path.join directory with a blank string to generate a path separator at the end.
result.append(os.path.join(root, filepath).replace(os.path.join(directory,''),'')) - return result + for d in subdir: + path = os.path.join(root, d) + dirs[path] = os.path.getmtime(path) + return result, dirs def filter_files_extensions(files, extensions): return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions, files))) @@ -136,13 +144,44 @@ def get_full_path(folder_name, filename): return None -def get_filename_list(folder_name): +def get_filename_list_(folder_name): global folder_names_and_paths output_list = set() folders = folder_names_and_paths[folder_name] + output_folders = {} for x in folders[0]: - output_list.update(filter_files_extensions(recursive_search(x), folders[1])) - return sorted(list(output_list)) + files, folders_all = recursive_search(x) + output_list.update(filter_files_extensions(files, folders[1])) + output_folders = {**output_folders, **folders_all} + + return (sorted(list(output_list)), output_folders) + +def cached_filename_list_(folder_name): + global filename_list_cache + global folder_names_and_paths + if folder_name not in filename_list_cache: + return None + out = filename_list_cache[folder_name] + for x in out[1]: + time_modified = out[1][x] + folder = x + if os.path.getmtime(folder) != time_modified: + return None + + folders = folder_names_and_paths[folder_name] + for x in folders[0]: + if x not in out[1]: + return None + + return out + +def get_filename_list(folder_name): + out = cached_filename_list_(folder_name) + if out is None: + out = get_filename_list_(folder_name) + global filename_list_cache + filename_list_cache[folder_name] = out + return out[0] def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def map_filename(filename): From 1f34bf08f06550fb2f041188b5a01d395240be17 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Wed, 31 May 2023 22:01:25 +0900 Subject: [PATCH 42/48] To support dynamic custom loading, separate the node registration process based on the defs in the registerNodes function. 
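Before that diff, a short aside on PATCH 41 above: its cache is keyed by folder name and invalidated by comparing directory mtimes recorded at scan time, so added or removed files trigger a rescan while untouched trees stay cached. A self-contained sketch of the same idea (names here are illustrative, not the actual folder_paths API):

import os

_cache = {}  # folder -> (files, {directory: mtime recorded at scan time})

def list_files_cached(folder):
    if not os.path.isdir(folder):
        return []
    cached = _cache.get(folder)
    if cached is not None:
        files, dir_mtimes = cached
        # Reuse the cached listing only while every watched directory
        # still exists and has the mtime we recorded when scanning.
        if all(os.path.isdir(d) and os.path.getmtime(d) == m
               for d, m in dir_mtimes.items()):
            return list(files)  # copy, so callers cannot mutate the cache
    files = []
    dir_mtimes = {folder: os.path.getmtime(folder)}
    for root, subdirs, names in os.walk(folder, followlinks=True):
        files.extend(os.path.join(root, n) for n in names)
        for d in subdirs:
            path = os.path.join(root, d)
            dir_mtimes[path] = os.path.getmtime(path)
    _cache[folder] = (files, dir_mtimes)
    return list(files)

Returning list(files) instead of the cached object anticipates PATCH 45 below, which stops callers from mutating the shared folder lists.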
--- web/scripts/app.js | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 64adc3e6a..9ecad8489 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1010,6 +1010,11 @@ export class ComfyApp { const app = this; // Load node definitions from the backend const defs = await api.getNodeDefs(); + this.registerNodesFromDefs(defs); + await this.#invokeExtensionsAsync("registerCustomNodes"); + } + + async registerNodesFromDefs(defs) { await this.#invokeExtensionsAsync("addCustomNodeDefs", defs); // Generate list of known widgets @@ -1082,8 +1087,6 @@ export class ComfyApp { LiteGraph.registerNodeType(nodeId, node); node.category = nodeData.category; } - - await this.#invokeExtensionsAsync("registerCustomNodes"); } /** From 8e8d6070f2e80aff0200bb3ad0f31716a98d5739 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Wed, 31 May 2023 23:26:56 +0900 Subject: [PATCH 43/48] race condition patch --- web/scripts/app.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/scripts/app.js b/web/scripts/app.js index 9ecad8489..8a9c7ca49 100644 --- a/web/scripts/app.js +++ b/web/scripts/app.js @@ -1010,7 +1010,7 @@ export class ComfyApp { const app = this; // Load node definitions from the backend const defs = await api.getNodeDefs(); - this.registerNodesFromDefs(defs); + await this.registerNodesFromDefs(defs); await this.#invokeExtensionsAsync("registerCustomNodes"); } From 03da8a34265bb333d03a51d7503697b5ede9b335 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 31 May 2023 13:03:24 -0400 Subject: [PATCH 44/48] This is useless for inference. --- comfy/sd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/sd.py b/comfy/sd.py index ce17994f7..fa7bd8d32 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -743,7 +743,7 @@ def load_controlnet(ckpt_path, model=None): use_spatial_transformer=True, transformer_depth=1, context_dim=context_dim, - use_checkpoint=True, + use_checkpoint=False, legacy=False, use_fp16=use_fp16) else: @@ -760,7 +760,7 @@ def load_controlnet(ckpt_path, model=None): use_linear_in_transformer=True, transformer_depth=1, context_dim=context_dim, - use_checkpoint=True, + use_checkpoint=False, legacy=False, use_fp16=use_fp16) if pth: @@ -1045,7 +1045,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o } unet_config = { - "use_checkpoint": True, + "use_checkpoint": False, "image_size": 32, "out_channels": 4, "attention_resolutions": [ From d200fa131420a8871633b7321664db419aab2712 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Wed, 31 May 2023 19:00:01 -0500 Subject: [PATCH 45/48] Prevent callers from mutating folder lists --- folder_paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/folder_paths.py b/folder_paths.py index f3d1b8773..e179a28d4 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -181,7 +181,7 @@ def get_filename_list(folder_name): out = get_filename_list_(folder_name) global filename_list_cache filename_list_cache[folder_name] = out - return out[0] + return list(out[0]) def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0): def map_filename(filename): From 94680732d32b4b540251c122aee36df8d37266e1 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 1 Jun 2023 03:52:51 -0400 Subject: [PATCH 46/48] Empty cache on mps. 
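The diff below adds an MPS branch to soft_empty_cache. As a standalone sketch, the backend dispatch looks roughly like this (assumes a PyTorch build recent enough to expose torch.mps.empty_cache; the CUDA branch is simplified relative to the surrounding code, which also special-cases ROCm):

import torch

def soft_empty_cache_sketch():
    # Release cached allocator memory on whichever backend is active;
    # illustrative only, mirroring the shape of the change below.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

soft_empty_cache_sketch()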
--- comfy/model_management.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 10a706793..60bcd786b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -389,7 +389,10 @@ def should_use_fp16(): def soft_empty_cache(): global xpu_available - if xpu_available: + global vram_state + if vram_state == VRAMState.MPS: + torch.mps.empty_cache() + elif xpu_available: torch.xpu.empty_cache() elif torch.cuda.is_available(): if torch.version.cuda: #This seems to make things worse on ROCm so I only do it for cuda From 5c38958e49efd11b5234cb5ff472d752698c5090 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Thu, 1 Jun 2023 04:04:35 -0400 Subject: [PATCH 47/48] Tweak lowvram model memory so it's closer to what it was before. --- comfy/model_management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 60bcd786b..e9af7f3a7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -207,7 +207,7 @@ def load_model_gpu(model): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): model_size = model.model_size() current_free_mem = get_free_memory(torch_dev) - lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.2 )) + lowvram_model_memory = int(max(256 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 )) if model_size > (current_free_mem - (512 * 1024 * 1024)): #only switch to lowvram if really necessary vram_set_state = VRAMState.LOW_VRAM From 1bbd3f7fe16e6637bba232059d004a5fe7804a30 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 1 Jun 2023 22:15:06 -0500 Subject: [PATCH 48/48] Send back prompt number from prompt/ endpoint --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 72c565a63..0b64df147 100644 --- a/server.py +++ b/server.py @@ -361,7 +361,7 @@ class PromptServer(): prompt_id = str(uuid.uuid4()) outputs_to_execute = valid[2] self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute)) - return web.json_response({"prompt_id": prompt_id}) + return web.json_response({"prompt_id": prompt_id, "number": number}) else: print("invalid prompt:", valid[1]) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
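With this last patch, a client gets both identifiers back from a successful POST to /prompt, and the error branch visible in the diff above still carries the per-node "node_errors" detail. A hedged usage sketch (assumes a ComfyUI server on the default 127.0.0.1:8188; the empty prompt is deliberately invalid so the error path fires):

import json
import urllib.error
import urllib.request

req = urllib.request.Request(
    "http://127.0.0.1:8188/prompt",
    data=json.dumps({"prompt": {}}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
try:
    with urllib.request.urlopen(req) as resp:
        out = json.loads(resp.read())
        # Success: the response now carries the queue number alongside prompt_id.
        print("queued", out["prompt_id"], "as number", out["number"])
except urllib.error.HTTPError as e:
    out = json.loads(e.read())
    print("invalid prompt:", out["error"])
    for node_id, err in out.get("node_errors", {}).items():
        print(" node", node_id, "->", err)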