basic subflow ui functionality

This commit is contained in:
Sammy Franklin 2023-10-07 00:50:03 -07:00
parent 0f39cfe403
commit e268692351
7 changed files with 133 additions and 1 deletions

1
.gitignore vendored
View File

@ -5,6 +5,7 @@ __pycache__/
!/input/example.png
/models/
/temp/
/subflows/
/custom_nodes/
!custom_nodes/example_node.py.example
extra_model_paths.yaml

View File

@ -0,0 +1,29 @@
import folder_paths
import json
import os.path as osp
class Subflow:
@classmethod
def INPUT_TYPES(s):
return {"required": { "subflow_name": (folder_paths.get_filename_list("subflows"), ), }}
RETURN_TYPES = ()
FUNCTION = "exec_subflow"
CATEGORY = "loaders"
def exec_subflow(self, subflow_name):
subflow_path = folder_paths.get_full_path("subflows", subflow_name)
with open(subflow_path) as f:
if osp.splitext(subflow_path)[1] == ".json":
subflow_data = json.load(f)
return subflow_data
return None
NODE_CLASS_MAPPINGS = {
"Subflow": Subflow,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Subflow": "Load Subflow"
}

View File

@ -2,6 +2,7 @@ import os
import time
supported_pt_extensions = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors'])
supported_subflow_extensions = set(['.json', '.png'])
folder_names_and_paths = {}
@ -29,15 +30,21 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
folder_names_and_paths["subflows"] = ([os.path.join(base_path, "subflows")], supported_subflow_extensions)
output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output")
temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp")
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
subflows_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "subflows")
filename_list_cache = {}
if not os.path.exists(input_directory):
os.makedirs(input_directory)
if not os.path.exists(subflows_directory):
os.makedirs(subflows_directory)
def set_output_directory(output_dir):
global output_directory
output_directory = output_dir
@ -58,6 +65,9 @@ def get_input_directory():
global input_directory
return input_directory
def get_subflows_directory():
global subflows_directory
return subflows_directory
#NOTE: used in http server so don't put folders that should not be accessed remotely
def get_directory_by_type(type_name):
@ -67,6 +77,8 @@ def get_directory_by_type(type_name):
return get_temp_directory()
if type_name == "input":
return get_input_directory()
if type_name == "subflows":
return get_subflows_directory()
return None
@ -82,6 +94,9 @@ def annotated_filepath(name):
elif name.endswith("[temp]"):
base_dir = get_temp_directory()
name = name[:-7]
elif name.endswith("[subflows]"):
base_dir = get_subflows_directory()
name = name[:-11]
else:
return name, None

View File

@ -1795,7 +1795,8 @@ def init_custom_nodes():
"nodes_clip_sdxl.py",
"nodes_canny.py",
"nodes_freelunch.py",
"nodes_custom_sampler.py"
"nodes_custom_sampler.py",
"nodes_subflow.py",
]
for node_file in extras_files:

View File

@ -1,4 +1,5 @@
import os
import os.path as osp
import sys
import asyncio
import traceback
@ -153,6 +154,8 @@ class PromptServer():
type_dir = folder_paths.get_temp_directory()
elif dir_type == "output":
type_dir = folder_paths.get_output_directory()
elif dir_type == "subflows":
type_dir = folder_paths.get_subflows_directory()
return type_dir, dir_type
@ -516,6 +519,21 @@ class PromptServer():
return web.Response(status=200)
@routes.get("/subflows/{subflow_name}")
async def get_subflow(request):
subflow_name = request.match_info.get("subflow_name", None)
if subflow_name != None:
subflow_path = folder_paths.get_full_path("subflows", subflow_name)
ext = osp.splitext(subflow_path)[1]
with open(subflow_path) as f:
if ext == ".json":
subflow_data = json.load(f)
return web.json_response({"subflow": subflow_data}, status=200)
elif ext == ".png":
return web.json_response({"error": "todo", "node_errors": []}, status=400)
return web.json_response({"error": "no subflow_name provided", "node_errors": []}, status=400)
def add_routes(self):
self.app.add_routes(self.routes)

View File

@ -0,0 +1,54 @@
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
app.registerExtension({
name: "Comfy.Subflow",
async nodeCreated(node) {
if (!node.widgets) return;
if (node.widgets[0].name !== "subflow_name") return;
const refreshPins = (subflowNodes) => {
// remove all existing pins
const numInputs = node.inputs.length;
const numOutputs = node.outputs.length;
for(let i = numInputs-1; i > -1; i--) {
node.removeInput(i);
}
for(let i = numOutputs-1; i > -1; i--) {
node.removeOutput(i);
}
for (const subflowNode of subflowNodes) {
const exports = subflowNode.properties.exports;
if (exports) {
for (const inputRef of exports.inputs) {
const input = subflowNode.inputs.find(q => q.name === inputRef);
if (!input) continue;
node.addInput(input.name, input.type);
}
for (const outputRef of exports.outputs) {
const output = subflowNode.outputs.find(q => q.name === outputRef);
if (!output) continue;
node.addOutput(output.name, output.type);
}
}
}
};
node.onConfigure = async function () {
const subflowData = await api.getSubflow(node.widgets[0].value);
if (subflowData.subflow) {
refreshPins(subflowData.subflow.nodes);
}
};
node.widgets[0].callback = async function (subflowName) {
const subflowData = await api.getSubflow(subflowName);
if (subflowData.subflow) {
refreshPins(subflowData.subflow.nodes);
}
};
}
});

View File

@ -264,6 +264,20 @@ class ComfyApi extends EventTarget {
}
}
/**
* Gets the subflow json data
* @returns Prompt history including node outputs
*/
async getSubflow(subflowName) {
try {
const res = await this.fetchApi(`/subflows/${subflowName}`);
return await res.json();
} catch (error) {
console.error(error);
return { };
}
}
/**
* Gets system & device stats
* @returns System stats such as python version, OS, per device info