from __future__ import annotations from typing import TypedDict import os import folder_paths import glob from aiohttp import web import hashlib class Source: custom_node = "custom_node" templates = "templates" class SubgraphEntry(TypedDict): source: str """ Source of subgraph - custom_nodes vs templates. """ path: str """ Relative path of the subgraph file. For custom nodes, will be the relative directory like /subgraphs/.json """ name: str """ Name of subgraph file. """ info: CustomNodeSubgraphEntryInfo """ Additional info about subgraph; in the case of custom_nodes, will contain nodepack name """ data: str class CustomNodeSubgraphEntryInfo(TypedDict): node_pack: str """Node pack name.""" class SubgraphManager: def __init__(self): self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]: """Create a subgraph entry from a file path. Expects normalized path (forward slashes).""" entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() entry: SubgraphEntry = { "source": source, "name": os.path.splitext(os.path.basename(file))[0], "path": file, "info": {"node_pack": node_pack}, } return entry_id, entry async def load_entry_data(self, entry: SubgraphEntry): with open(entry['path'], 'r') as f: entry['data'] = f.read() return entry async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None: if entry is None: return None entry = entry.copy() entry.pop('path', None) if remove_data: entry.pop('data', None) return entry async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]: entries = entries.copy() for key in list(entries.keys()): entries[key] = await self.sanitize_entry(entries[key], remove_data) return entries async def get_custom_node_subgraphs(self, loadedModules, force_reload=False): """Load subgraphs from custom nodes.""" if not force_reload and self.cached_custom_node_subgraphs is not None: return self.cached_custom_node_subgraphs subgraphs_dict: dict[SubgraphEntry] = {} for folder in folder_paths.get_folder_paths("custom_nodes"): pattern = os.path.join(folder, "*/subgraphs/*.json") for file in glob.glob(pattern): file = file.replace('\\', '/') node_pack = "custom_nodes." + file.split('/')[-3] entry_id, entry = self._create_entry(file, Source.custom_node, node_pack) subgraphs_dict[entry_id] = entry self.cached_custom_node_subgraphs = subgraphs_dict return subgraphs_dict async def get_blueprint_subgraphs(self, force_reload=False): """Load subgraphs from the blueprints directory.""" if not force_reload and self.cached_blueprint_subgraphs is not None: return self.cached_blueprint_subgraphs subgraphs_dict: dict[SubgraphEntry] = {} blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints') if os.path.exists(blueprints_dir): for file in glob.glob(os.path.join(blueprints_dir, "*.json")): file = file.replace('\\', '/') entry_id, entry = self._create_entry(file, Source.templates, "comfyui") subgraphs_dict[entry_id] = entry self.cached_blueprint_subgraphs = subgraphs_dict return subgraphs_dict async def get_all_subgraphs(self, loadedModules, force_reload=False): """Get all subgraphs from all sources (custom nodes and blueprints).""" custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload) blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload) return {**custom_node_subgraphs, **blueprint_subgraphs} async def get_subgraph(self, id: str, loadedModules): """Get a specific subgraph by ID from any source.""" entry = (await self.get_all_subgraphs(loadedModules)).get(id) if entry is not None and entry.get('data') is None: await self.load_entry_data(entry) return entry def add_routes(self, routes, loadedModules): @routes.get("/global_subgraphs") async def get_global_subgraphs(request): subgraphs_dict = await self.get_all_subgraphs(loadedModules) return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True)) @routes.get("/global_subgraphs/{id}") async def get_global_subgraph(request): id = request.match_info.get("id", None) subgraph = await self.get_subgraph(id, loadedModules) return web.json_response(await self.sanitize_entry(subgraph))