diff --git a/execution.py b/execution.py index 3ca551db6..864d9943f 100644 --- a/execution.py +++ b/execution.py @@ -247,7 +247,9 @@ def validate_inputs(prompt, item): if isinstance(type_input, list): if val not in type_input: - return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) + # bypass validation on special output + if not val.endswith(" [OUT]") and not val.endswith(" [TEMP]"): + return (False, "Value not in list. {}, {}: {} not in {}".format(class_type, x, val, type_input)) return (True, "") def validate_prompt(prompt): diff --git a/nodes.py b/nodes.py index 4eddfafe9..f72521a4d 100644 --- a/nodes.py +++ b/nodes.py @@ -730,7 +730,7 @@ class SaveImage: "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } - WIDGET_TYPES = {"send to img": ("IMAGESEND", )} + WIDGET_TYPES = {"send to img": ("IMAGESEND", "OUT")} RETURN_TYPES = () FUNCTION = "save_images" @@ -810,8 +810,12 @@ class PreviewImage(SaveImage): "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } + WIDGET_TYPES = {"send to img": ("IMAGESEND", "TEMP")} + class LoadImage: input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") + temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") @classmethod def INPUT_TYPES(s): if not os.path.exists(s.input_dir): @@ -826,8 +830,17 @@ class LoadImage: RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" + + def get_image_path(self, image): + if image.endswith(" [OUT]"): + return os.path.join(self.output_dir, image[:-6]) + elif image.endswith(" [TEMP]"): + return os.path.join(self.temp_dir, image[:-7]) + else: + return os.path.join(self.input_dir, image) + def load_image(self, image): - image_path = os.path.join(self.input_dir, image) + image_path = LoadImage.get_image_path(self, image) i = Image.open(image_path) image = i.convert("RGB") image = np.array(image).astype(np.float32) / 255.0 @@ -841,7 +854,7 @@ class LoadImage: @classmethod def IS_CHANGED(s, image): - image_path = os.path.join(s.input_dir, image) + image_path = LoadImage.get_image_path(s, image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) @@ -849,6 +862,8 @@ class LoadImage: class LoadImageMask: input_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") + output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") + temp_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") @classmethod def INPUT_TYPES(s): return {"required": @@ -863,7 +878,7 @@ class LoadImageMask: RETURN_TYPES = ("MASK",) FUNCTION = "load_image" def load_image(self, image, channel): - image_path = os.path.join(self.input_dir, image) + image_path = LoadImage.get_image_path(self, image) i = Image.open(image_path) mask = None c = channel[0].upper() @@ -878,7 +893,7 @@ class LoadImageMask: @classmethod def IS_CHANGED(s, image, channel): - image_path = os.path.join(s.input_dir, image) + image_path = LoadImage.get_image_path(s, image) m = hashlib.sha256() with open(image_path, 'rb') as f: m.update(f.read()) diff --git a/server.py b/server.py index 7502e5170..1ad3727c1 100644 --- a/server.py +++ b/server.py @@ -7,7 +7,6 @@ import execution import uuid import json import glob -from shutil import copyfile try: import aiohttp @@ -88,17 +87,6 @@ class PromptServer(): files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True) return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))) - @routes.get("/image/output_to_input/{name}") - async def copy_output_to_input_image(request): - name = request.match_info["name"] - - src_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output", name) - dest_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input", name) - - copyfile(src_dir,dest_dir) - - return web.Response(status=200) - @routes.post("/upload/image") async def upload_image(request): upload_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") diff --git a/web/scripts/widgets.js b/web/scripts/widgets.js index 0ce1559dc..a66ef745c 100644 --- a/web/scripts/widgets.js +++ b/web/scripts/widgets.js @@ -47,7 +47,7 @@ function seedWidget(node, inputName, inputData) { } function imagesendWidget(node, inputName, inputData, app) { - function showImage(node,uploadWidget,name) { + function showImage(node,uploadWidget,name,type) { // Position the image somewhere sensible if (!node.imageOffset) { node.imageOffset = uploadWidget.last_y ? uploadWidget.last_y + 50 : 100; @@ -58,7 +58,13 @@ function imagesendWidget(node, inputName, inputData, app) { node.imgs = [img]; app.graph.setDirtyCanvas(true); }; - img.src = `/view?filename=${name}&type=input`; + + if(type == "OUT") + img.src = `/view?filename=${name}&type=output`; + else if(type == "TEMP") + img.src = `/view?filename=${name}&type=temp`; + else + img.src = `/view?filename=${name}&type=input`; } async function callback() { @@ -75,15 +81,9 @@ function imagesendWidget(node, inputName, inputData, app) { const recvWidget = n.widgets.find((w) => w.name === "recv img"); if(recvWidget.value == "enable") { - // copy current node image to 'recv img' enabled node - - if(!copied) { - await api.sendOutputToInputImage(image_name); - } - - imageWidget.value = image_name; + imageWidget.value = image_name + ` [${inputData[1]}]`; const thatImageWidget = n.widgets.find((w) => w.value === "image"); - await showImage(n,thatImageWidget,image_name); + await showImage(n,thatImageWidget,image_name,inputData[1]); } } }