ComfyUI/app/subgraph_manager.py

88 lines
3.1 KiB
Python

from __future__ import annotations
from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib
class Source:
custom_node = "custom_node"
class SubgraphEntry(TypedDict):
source: str
"""
Source of subgraph - custom_nodes vs templates.
"""
path: str
"""
Relative path of the subgraph file.
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
"""
name: str
"""
Name of subgraph file.
"""
info: CustomNodeSubgraphEntryInfo
"""
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
"""
class CustomNodeSubgraphEntryInfo(TypedDict):
node_pack: str
"""Node pack name."""
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
async def get_custom_node_subgraphs(self, force_reload=False):
# if not forced to reload and cached, return cache
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
info: CustomNodeSubgraphEntryInfo = {
"node_pack": folder
}
path = f"{folder}/{file.split(folder)[-1]}"
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{path}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": path,
"info": info
}
subgraphs_dict[id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_custom_node_subgraph(self, id: str):
subgraphs = await self.get_custom_node_subgraphs()
return subgraphs.get(id, None)
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
return web.json_response(subgraphs_dict)
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_custom_node_subgraph(id)
return web.json_response(subgraph)