dev: refactor; populate models in more nodes; use Pydantic in endpoints for input validation

bigcat88 2025-08-23 20:14:22 +03:00
parent f92307cd4c
commit 5c1b5973ac
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
8 changed files with 141 additions and 280 deletions

View File

@@ -124,8 +124,12 @@ def upgrade() -> None:
{"name": "upscale", "tag_type": "system"},
{"name": "diffusion-model", "tag_type": "system"},
{"name": "hypernetwork", "tag_type": "system"},
{"name": "vae_approx", "tag_type": "system"},
# TODO: decide what to do with: style_models, diffusers, gligen, photomaker, classifiers
{"name": "vae-approx", "tag_type": "system"},
{"name": "gligen", "tag_type": "system"},
{"name": "style-model", "tag_type": "system"},
{"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"},
# TODO: decide what to do with: photomaker, classifiers
],
)
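
These seeded names match the tags the loader nodes attach further down (e.g. "vae-approx", "encoder", "decoder", "style-model", "gligen"). A minimal sketch of filtering on them through the asset service, using the list_assets signature visible in the routes below (the query values themselves are hypothetical):

from app import assets_manager

async def newest_taesd_encoders():
    # All registered TAESD encoder files, newest first (hypothetical query).
    return await assets_manager.list_assets(
        include_tags=["models", "vae-approx", "encoder"],
        exclude_tags=[],
        name_contains=None,
        metadata_filter=None,
        limit=20,
        offset=0,
        sort="created_at",
        order="desc",
    )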

View File

@@ -1,8 +1,10 @@
import json
from typing import Sequence
from aiohttp import web
from typing import Optional
from app import assets_manager
from aiohttp import web
from pydantic import ValidationError
from .. import assets_manager
from .schemas_in import ListAssetsQuery, UpdateAssetBody
ROUTES = web.RouteTableDef()
@@ -10,38 +12,22 @@ ROUTES = web.RouteTableDef()
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
q = request.rel_url.query
query_dict = dict(request.rel_url.query)
include_tags: Sequence[str] = _parse_csv_tags(q.get("include_tags"))
exclude_tags: Sequence[str] = _parse_csv_tags(q.get("exclude_tags"))
name_contains = q.get("name_contains")
# Optional JSON metadata filter (top-level key equality only for now)
metadata_filter = None
raw_meta = q.get("metadata_filter")
if raw_meta:
try:
metadata_filter = json.loads(raw_meta)
if not isinstance(metadata_filter, dict):
metadata_filter = None
except Exception:
# Silently ignore malformed JSON for first iteration; could 400 in future
metadata_filter = None
limit = _parse_int(q.get("limit"), default=20, lo=1, hi=100)
offset = _parse_int(q.get("offset"), default=0, lo=0, hi=10_000_000)
sort = q.get("sort", "created_at")
order = q.get("order", "desc")
try:
q = ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = await assets_manager.list_assets(
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=q.sort,
order=q.order,
)
return web.json_response(payload)
@@ -55,29 +41,18 @@ async def update_asset(request: web.Request) -> web.Response:
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
try:
payload = await request.json()
body = UpdateAssetBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
name = payload.get("name", None)
tags = payload.get("tags", None)
user_metadata = payload.get("user_metadata", None)
if name is None and tags is None and user_metadata is None:
return _error_response(400, "NO_FIELDS", "Provide at least one of: name, tags, user_metadata.")
if tags is not None and (not isinstance(tags, list) or not all(isinstance(t, str) for t in tags)):
return _error_response(400, "INVALID_TAGS", "Field 'tags' must be an array of strings.")
if user_metadata is not None and not isinstance(user_metadata, dict):
return _error_response(400, "INVALID_METADATA", "Field 'user_metadata' must be an object.")
try:
result = await assets_manager.update_asset(
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
@@ -90,21 +65,9 @@ def register_assets_routes(app: web.Application) -> None:
app.add_routes(ROUTES)
def _parse_csv_tags(raw: str | None) -> list[str]:
if not raw:
return []
return [t.strip() for t in raw.split(",") if t.strip()]
def _parse_int(qval: str | None, default: int, lo: int, hi: int) -> int:
if not qval:
return default
try:
v = int(qval)
except Exception:
return default
return max(lo, min(hi, v))
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
def _error_response(status: int, code: str, message: str, details: Optional[dict] = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.errors()})
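
A request that fails validation now comes back as HTTP 400 with Pydantic's error list wrapped in the envelope built by _error_response. A minimal client-side sketch (host and port are assumptions, not part of this commit):

import asyncio
import aiohttp

async def demo():
    async with aiohttp.ClientSession() as session:
        # limit=9999 violates the le=500 bound declared on ListAssetsQuery below
        async with session.get("http://127.0.0.1:8188/api/assets", params={"limit": "9999"}) as resp:
            body = await resp.json()
            # Expected: {"error": {"code": "INVALID_QUERY", "message": "Validation failed.",
            #                      "details": {"errors": [...]}}}
            assert resp.status == 400
            assert body["error"]["code"] == "INVALID_QUERY"

asyncio.run(demo())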

app/api/schemas_in.py (new file, 66 lines)
View File

@@ -0,0 +1,66 @@
from __future__ import annotations
from typing import Any, Optional, Literal
from pydantic import BaseModel, Field, field_validator, model_validator, conint
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: Optional[str] = None
# Accept either a JSON string (query param) or a dict
metadata_filter: Optional[dict[str, Any]] = None
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
import json
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class UpdateAssetBody(BaseModel):
name: Optional[str] = None
tags: Optional[list[str]] = None
user_metadata: Optional[dict[str, Any]] = None
@model_validator(mode="after")
def _at_least_one(self):
if self.name is None and self.tags is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, tags, user_metadata.")
if self.tags is not None:
if not isinstance(self.tags, list) or not all(isinstance(t, str) for t in self.tags):
raise ValueError("Field 'tags' must be an array of strings.")
return self
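
A minimal sketch of what these validators accept, exercised directly against the module above:

from app.api.schemas_in import ListAssetsQuery, UpdateAssetBody

q = ListAssetsQuery.model_validate({
    "include_tags": "models, checkpoint",      # CSV string is split and stripped
    "metadata_filter": '{"purpose": "test"}',  # JSON string is parsed into a dict
    "limit": "50",                             # coerced to int, range-checked (1..500)
})
assert q.include_tags == ["models", "checkpoint"]
assert q.metadata_filter == {"purpose": "test"}
assert (q.sort, q.order) == ("created_at", "desc")  # defaults apply

# An empty update body is rejected by the model validator:
try:
    UpdateAssetBody.model_validate({})
except ValueError:
    pass  # "Provide at least one of: name, tags, user_metadata."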

View File

@@ -2,6 +2,9 @@ import os
from datetime import datetime, timezone
from typing import Optional, Sequence
from comfy.cli_args import args
from comfy_api.internal import async_to_sync
from .database.db import create_session
from .storage import hashing
from .database.services import (
@@ -14,9 +17,11 @@ from .database.services import (
)
def get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None:
if not args.disable_model_processing:
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
add_local_asset, tags=tags, file_name=file_name, file_path=file_path
)
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
@@ -28,7 +33,7 @@ async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
- This function ensures the identity row and seeds mtime in asset_locator_state.
"""
abs_path = os.path.abspath(file_path)
size_bytes, mtime_ns = get_size_mtime_ns(abs_path)
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
if not size_bytes:
return
@@ -146,3 +151,8 @@ def _safe_sort_field(requested: str | None) -> str:
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
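
populate_db_with_asset is the synchronous entry point the loader nodes call: it bridges into the async add_local_asset via AsyncToSyncConverter.run_async_in_thread and is a no-op when --disable-model-processing is set. A minimal sketch of a direct call (the path is hypothetical):

from app.assets_manager import populate_db_with_asset

# Ensures the asset's identity row exists and seeds its mtime in
# asset_locator_state; empty files are skipped.
populate_db_with_asset(
    tags=["models", "checkpoint"],
    file_name="sd_xl_base_1.0.safetensors",
    file_path="/opt/comfyui/models/checkpoints/sd_xl_base_1.0.safetensors",
)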

View File

@@ -1,195 +0,0 @@
from __future__ import annotations
import os
import base64
import json
import time
import logging
import folder_paths
import glob
import comfy.utils
from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types
class ModelFileManager:
def __init__(self) -> None:
self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}
def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
return self.cache.get(key, default)
def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
self.cache[key] = value
def clear_cache(self):
self.cache.clear()
def add_routes(self, routes):
# NOTE: This is an experiment to replace `/models`
@routes.get("/experiment/models")
async def get_model_folders(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
folder_black_list = ["configs", "custom_nodes"]
output_folders: list[dict] = []
for folder in model_types:
if folder in folder_black_list:
continue
output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
return web.json_response(output_folders)
# NOTE: This is an experiment to replace `/models/{folder}`
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
async def get_model_preview(request):
folder_name = request.match_info.get("folder", None)
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if not folder_name in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]
folder = folders[0][path_index]
full_filename = os.path.join(folder, filename)
previews = self.get_model_previews(full_filename)
default_preview = previews[0] if len(previews) > 0 else None
if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
return web.Response(status=404)
try:
with Image.open(default_preview) as img:
img_bytes = BytesIO()
img.save(img_bytes, format="WEBP")
img_bytes.seek(0)
return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
except:
return web.Response(status=404)
def get_model_file_list(self, folder_name: str):
folder_name = map_legacy(folder_name)
folders = folder_paths.folder_names_and_paths[folder_name]
output_list: list[dict] = []
for index, folder in enumerate(folders[0]):
if not os.path.isdir(folder):
continue
out = self.cache_model_file_list_(folder)
if out is None:
out = self.recursive_search_models_(folder, index)
self.set_cache(folder, out)
output_list.extend(out[0])
return output_list
def cache_model_file_list_(self, folder: str):
model_file_list_cache = self.get_cache(folder)
if model_file_list_cache is None:
return None
if not os.path.isdir(folder):
return None
if os.path.getmtime(folder) != model_file_list_cache[1]:
return None
for x in model_file_list_cache[1]:
time_modified = model_file_list_cache[1][x]
folder = x
if os.path.getmtime(folder) != time_modified:
return None
return model_file_list_cache
def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
if not os.path.isdir(directory):
return [], {}, time.perf_counter()
excluded_dir_names = [".git"]
# TODO use settings
include_hidden_files = False
result: list[str] = []
dirs: dict[str, float] = {}
for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
if not include_hidden_files:
subdirs[:] = [d for d in subdirs if not d.startswith(".")]
filenames = [f for f in filenames if not f.startswith(".")]
filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)
for file_name in filenames:
try:
full_path = os.path.join(dirpath, file_name)
relative_path = os.path.relpath(full_path, directory)
# Get file metadata
file_info = {
"name": relative_path,
"pathIndex": pathIndex,
"modified": os.path.getmtime(full_path), # Add modification time
"created": os.path.getctime(full_path), # Add creation time
"size": os.path.getsize(full_path) # Add file size
}
result.append(file_info)
except Exception as e:
logging.warning(f"Warning: Unable to access {file_name}. Error: {e}. Skipping this file.")
continue
for d in subdirs:
path: str = os.path.join(dirpath, d)
try:
dirs[path] = os.path.getmtime(path)
except FileNotFoundError:
logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
continue
return result, dirs, time.perf_counter()
def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
return []
basename = os.path.splitext(filepath)[0]
match_files = glob.glob(f"{basename}.*", recursive=False)
image_files = filter_files_content_types(match_files, "image")
safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
safetensors_metadata = {}
result: list[str | BytesIO] = []
for filename in image_files:
_basename = os.path.splitext(filename)[0]
if _basename == basename:
result.append(filename)
if _basename == f"{basename}.preview":
result.append(filename)
if safetensors_file:
safetensors_filepath = os.path.join(dirname, safetensors_file)
header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
if header:
safetensors_metadata = json.loads(header)
safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
if safetensors_images:
safetensors_images = json.loads(safetensors_images)
for image in safetensors_images:
result.append(BytesIO(base64.b64decode(image)))
return result
def __exit__(self, exc_type, exc_value, traceback):
self.clear_cache()

View File

@@ -212,7 +212,7 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable model file processing, e.g. computing hashes and extracting metadata.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable automatic processing of the model file, such as calculating hashes and populating the database.")
if comfy.options.args_parsing:
args = parser.parse_args()
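
With the new flag set, populate_db_with_asset above becomes a no-op, so loader nodes skip hashing and database writes entirely; e.g. start with: python main.py --disable-model-processing (entry-point name assumed, not shown in this diff).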

View File

@@ -28,10 +28,10 @@ import comfy.sd
import comfy.utils
import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from comfy_api.internal import async_to_sync, register_versions, ComfyAPIWithVersion
from comfy_api.internal import register_versions, ComfyAPIWithVersion
from comfy_api.version_list import supported_versions
from comfy_api.latest import io, ComfyExtension
from app.assets_manager import add_local_asset
from app.assets_manager import populate_db_with_asset
import comfy.clip_vision
@@ -555,7 +555,9 @@ class CheckpointLoader:
def load_checkpoint(self, config_name, ckpt_name):
config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
return comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
out = comfy.sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path)
return out
class CheckpointLoaderSimple:
@classmethod
@@ -577,6 +579,7 @@ class CheckpointLoaderSimple:
def load_checkpoint(self, ckpt_name):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path)
return out[:3]
class DiffusersLoader:
@@ -619,6 +622,7 @@ class unCLIPCheckpointLoader:
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
populate_db_with_asset(["models", "checkpoint"], ckpt_name, ckpt_path)
return out
class CLIPSetLastLayer:
@@ -677,6 +681,7 @@ class LoraLoader:
self.loaded_lora = (lora_path, lora)
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
populate_db_with_asset(["models", "lora"], lora_name, lora_path)
return (model_lora, clip_lora)
class LoraLoaderModelOnly(LoraLoader):
@@ -741,11 +746,15 @@ class VAELoader:
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
encoder_path = folder_paths.get_full_path_or_raise("vae_approx", encoder)
populate_db_with_asset(["models", "vae-approx", "encoder"], name, encoder_path)
enc = comfy.utils.load_torch_file(encoder_path)
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
dec = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
decoder_path = folder_paths.get_full_path_or_raise("vae_approx", decoder)
populate_db_with_asset(["models", "vae-approx", "decoder"], name, decoder_path)
dec = comfy.utils.load_torch_file(decoder_path)
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
@@ -778,9 +787,7 @@ class VAELoader:
else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
add_local_asset, tags=["models", "vae"], file_name=vae_name, file_path=vae_path
)
populate_db_with_asset(["models", "vae"], vae_name, vae_path)
vae = comfy.sd.VAE(sd=sd)
vae.throw_exception_if_invalid()
return (vae,)
@@ -800,6 +807,7 @@ class ControlNetLoader:
controlnet = comfy.controlnet.load_controlnet(controlnet_path)
if controlnet is None:
raise RuntimeError("ERROR: controlnet file is invalid and does not contain a valid controlnet model.")
populate_db_with_asset(["models", "controlnet"], control_net_name, controlnet_path)
return (controlnet,)
class DiffControlNetLoader:
@@ -816,6 +824,7 @@ class DiffControlNetLoader:
def load_controlnet(self, model, control_net_name):
controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name)
controlnet = comfy.controlnet.load_controlnet(controlnet_path, model)
populate_db_with_asset(["models", "controlnet"], control_net_name, controlnet_path)
return (controlnet,)
@@ -923,6 +932,7 @@ class UNETLoader:
unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name)
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
populate_db_with_asset(["models", "diffusion-model"], unet_name, unet_path)
return (model,)
class CLIPLoader:
@@ -950,6 +960,7 @@ class CLIPLoader:
clip_path = folder_paths.get_full_path_or_raise("text_encoders", clip_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
populate_db_with_asset(["models", "text-encoder"], clip_name, clip_path)
return (clip,)
class DualCLIPLoader:
@@ -980,6 +991,8 @@ class DualCLIPLoader:
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
populate_db_with_asset(["models", "text-encoder"], clip_name1, clip_path1)
populate_db_with_asset(["models", "text-encoder"], clip_name2, clip_path2)
return (clip,)
class CLIPVisionLoader:
@@ -997,6 +1010,7 @@ class CLIPVisionLoader:
clip_vision = comfy.clip_vision.load(clip_path)
if clip_vision is None:
raise RuntimeError("ERROR: clip vision file is invalid and does not contain a valid vision model.")
populate_db_with_asset(["models", "clip-vision"], clip_name, clip_path)
return (clip_vision,)
class CLIPVisionEncode:
@@ -1031,6 +1045,7 @@ class StyleModelLoader:
def load_style_model(self, style_model_name):
style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name)
style_model = comfy.sd.load_style_model(style_model_path)
populate_db_with_asset(["models", "style-model"], style_model_name, style_model_path)
return (style_model,)
@@ -1128,6 +1143,7 @@ class GLIGENLoader:
def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name)
gligen = comfy.sd.load_gligen(gligen_path)
populate_db_with_asset(["models", "gligen"], gligen_name, gligen_path)
return (gligen,)
class GLIGENTextBoxApply:

View File

@@ -33,7 +33,6 @@ from app.frontend_management import FrontendManager
from comfy_api.internal import _ComfyNodeInternal
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
@@ -155,7 +154,6 @@ class PromptServer():
mimetypes.add_type('image/webp', '.webp')
self.user_manager = UserManager()
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
@@ -764,7 +762,6 @@ def add_routes(self):
def add_routes(self):
self.user_manager.add_routes(self.routes)
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.app.add_subapp('/internal', self.internal_routes.get_app())