mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
feat: add research API routes (aiohttp, projects, papers, claims, sources, feed)
This commit is contained in:
parent
63df766808
commit
2199e56581
1
research_api/routes/__init__.py
Normal file
1
research_api/routes/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
# Research API routes package
|
||||
309
research_api/routes/_db_helpers.py
Normal file
309
research_api/routes/_db_helpers.py
Normal file
@ -0,0 +1,309 @@
|
||||
"""Async DB helpers that wrap sync SQLAlchemy with run_in_executor."""
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from research_api.db import create_session
|
||||
from research_api.models import Project, Intent, PaperAsset, ClaimAsset, Source, FeedItem
|
||||
from app.database.models import to_dict
|
||||
|
||||
|
||||
def _sync_list_projects():
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(Project).order_by(Project.updated_at.desc()))
|
||||
return [to_dict(p) for p in result.scalars().all()]
|
||||
|
||||
|
||||
def _sync_create_project(data):
|
||||
with create_session() as session:
|
||||
project = Project(**data)
|
||||
session.add(project)
|
||||
session.commit()
|
||||
session.refresh(project)
|
||||
return to_dict(project)
|
||||
|
||||
|
||||
def _sync_get_project(project_id):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(Project).where(Project.id == project_id))
|
||||
p = result.scalar_one_or_none()
|
||||
return to_dict(p) if p else None
|
||||
|
||||
|
||||
def _sync_list_intents(project_id):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(
|
||||
select(Intent).where(Intent.project_id == project_id).order_by(Intent.priority.desc())
|
||||
)
|
||||
return [to_dict(i) for i in result.scalars().all()]
|
||||
|
||||
|
||||
def _sync_create_intent(data):
|
||||
with create_session() as session:
|
||||
intent = Intent(**data)
|
||||
session.add(intent)
|
||||
session.commit()
|
||||
session.refresh(intent)
|
||||
return to_dict(intent)
|
||||
|
||||
|
||||
def _sync_update_intent(intent_id, data):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(Intent).where(Intent.id == intent_id))
|
||||
intent = result.scalar_one_or_none()
|
||||
if not intent:
|
||||
return None
|
||||
for key, value in data.items():
|
||||
if hasattr(intent, key):
|
||||
setattr(intent, key, value)
|
||||
session.commit()
|
||||
session.refresh(intent)
|
||||
return to_dict(intent)
|
||||
|
||||
|
||||
def _sync_list_papers(library_status=None, read_status=None):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
query = select(PaperAsset)
|
||||
if library_status:
|
||||
query = query.where(PaperAsset.library_status == library_status)
|
||||
if read_status:
|
||||
query = query.where(PaperAsset.read_status == read_status)
|
||||
query = query.order_by(PaperAsset.updated_at.desc())
|
||||
result = session.execute(query)
|
||||
return [to_dict(p) for p in result.scalars().all()]
|
||||
|
||||
|
||||
def _sync_create_paper(data):
|
||||
with create_session() as session:
|
||||
paper = PaperAsset(**data)
|
||||
session.add(paper)
|
||||
session.commit()
|
||||
session.refresh(paper)
|
||||
return to_dict(paper)
|
||||
|
||||
|
||||
def _sync_get_paper(paper_id):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(PaperAsset).where(PaperAsset.id == paper_id))
|
||||
p = result.scalar_one_or_none()
|
||||
return to_dict(p) if p else None
|
||||
|
||||
|
||||
def _sync_update_paper(paper_id, data):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(PaperAsset).where(PaperAsset.id == paper_id))
|
||||
paper = result.scalar_one_or_none()
|
||||
if not paper:
|
||||
return None
|
||||
for key, value in data.items():
|
||||
if hasattr(paper, key):
|
||||
setattr(paper, key, value)
|
||||
session.commit()
|
||||
session.refresh(paper)
|
||||
return to_dict(paper)
|
||||
|
||||
|
||||
def _sync_list_claims(project_id=None, support_level=None):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
query = select(ClaimAsset)
|
||||
if project_id:
|
||||
query = query.where(ClaimAsset.project_id == project_id)
|
||||
if support_level:
|
||||
query = query.where(ClaimAsset.support_level == support_level)
|
||||
query = query.order_by(ClaimAsset.updated_at.desc())
|
||||
result = session.execute(query)
|
||||
return [to_dict(c) for c in result.scalars().all()]
|
||||
|
||||
|
||||
def _sync_create_claim(data):
|
||||
with create_session() as session:
|
||||
claim = ClaimAsset(**data)
|
||||
session.add(claim)
|
||||
session.commit()
|
||||
session.refresh(claim)
|
||||
return to_dict(claim)
|
||||
|
||||
|
||||
def _sync_update_claim(claim_id, data):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(ClaimAsset).where(ClaimAsset.id == claim_id))
|
||||
claim = result.scalar_one_or_none()
|
||||
if not claim:
|
||||
return None
|
||||
for key, value in data.items():
|
||||
if hasattr(claim, key):
|
||||
setattr(claim, key, value)
|
||||
session.commit()
|
||||
session.refresh(claim)
|
||||
return to_dict(claim)
|
||||
|
||||
|
||||
def _sync_list_sources():
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(Source).order_by(Source.priority.desc()))
|
||||
return [to_dict(s) for s in result.scalars().all()]
|
||||
|
||||
|
||||
def _sync_create_source(data):
|
||||
with create_session() as session:
|
||||
source = Source(**data)
|
||||
session.add(source)
|
||||
session.commit()
|
||||
session.refresh(source)
|
||||
return to_dict(source)
|
||||
|
||||
|
||||
def _sync_update_source(source_id, data):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(Source).where(Source.id == source_id))
|
||||
source = result.scalar_one_or_none()
|
||||
if not source:
|
||||
return None
|
||||
for key, value in data.items():
|
||||
if hasattr(source, key):
|
||||
setattr(source, key, value)
|
||||
session.commit()
|
||||
session.refresh(source)
|
||||
return to_dict(source)
|
||||
|
||||
|
||||
def _sync_get_today_feed():
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(
|
||||
select(FeedItem)
|
||||
.where(FeedItem.status.in_(["discovered", "ranked", "presented"]))
|
||||
.order_by(FeedItem.rank_score.desc())
|
||||
.limit(20)
|
||||
)
|
||||
return [to_dict(i) for i in result.scalars().all()]
|
||||
|
||||
|
||||
def _sync_list_feed(source_id=None, status=None):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
query = select(FeedItem)
|
||||
if source_id:
|
||||
query = query.where(FeedItem.source_id == source_id)
|
||||
if status:
|
||||
query = query.where(FeedItem.status == status)
|
||||
query = query.order_by(FeedItem.rank_score.desc())
|
||||
result = session.execute(query)
|
||||
return [to_dict(i) for i in result.scalars().all()]
|
||||
|
||||
|
||||
def _sync_create_feed_item(data):
|
||||
with create_session() as session:
|
||||
item = FeedItem(**data)
|
||||
session.add(item)
|
||||
session.commit()
|
||||
session.refresh(item)
|
||||
return to_dict(item)
|
||||
|
||||
|
||||
def _sync_update_feed_item(item_id, data):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(FeedItem).where(FeedItem.id == item_id))
|
||||
item = result.scalar_one_or_none()
|
||||
if not item:
|
||||
return None
|
||||
for key, value in data.items():
|
||||
if hasattr(item, key):
|
||||
setattr(item, key, value)
|
||||
session.commit()
|
||||
session.refresh(item)
|
||||
return to_dict(item)
|
||||
|
||||
|
||||
# Async wrappers using run_in_executor
|
||||
loop = asyncio.get_event_loop
|
||||
|
||||
|
||||
async def asyncio_get_projects():
|
||||
return await loop().run_in_executor(None, _sync_list_projects)
|
||||
|
||||
|
||||
async def asyncio_create_project(data):
|
||||
return await loop().run_in_executor(None, partial(_sync_create_project, data))
|
||||
|
||||
|
||||
async def asyncio_get_project(project_id):
|
||||
return await loop().run_in_executor(None, partial(_sync_get_project, project_id))
|
||||
|
||||
|
||||
async def asyncio_list_intents(project_id):
|
||||
return await loop().run_in_executor(None, partial(_sync_list_intents, project_id))
|
||||
|
||||
|
||||
async def asyncio_create_intent(data):
|
||||
return await loop().run_in_executor(None, partial(_sync_create_intent, data))
|
||||
|
||||
|
||||
async def asyncio_update_intent(intent_id, data):
|
||||
return await loop().run_in_executor(None, partial(_sync_update_intent, intent_id, data))
|
||||
|
||||
|
||||
async def asyncio_list_papers(library_status=None, read_status=None):
|
||||
return await loop().run_in_executor(None, partial(_sync_list_papers, library_status, read_status))
|
||||
|
||||
|
||||
async def asyncio_create_paper(data):
|
||||
return await loop().run_in_executor(None, partial(_sync_create_paper, data))
|
||||
|
||||
|
||||
async def asyncio_get_paper(paper_id):
|
||||
return await loop().run_in_executor(None, partial(_sync_get_paper, paper_id))
|
||||
|
||||
|
||||
async def asyncio_update_paper(paper_id, data):
|
||||
return await loop().run_in_executor(None, partial(_sync_update_paper, paper_id, data))
|
||||
|
||||
|
||||
async def asyncio_list_claims(project_id=None, support_level=None):
|
||||
return await loop().run_in_executor(None, partial(_sync_list_claims, project_id, support_level))
|
||||
|
||||
|
||||
async def asyncio_create_claim(data):
|
||||
return await loop().run_in_executor(None, partial(_sync_create_claim, data))
|
||||
|
||||
|
||||
async def asyncio_update_claim(claim_id, data):
|
||||
return await loop().run_in_executor(None, partial(_sync_update_claim, claim_id, data))
|
||||
|
||||
|
||||
async def asyncio_list_sources():
|
||||
return await loop().run_in_executor(None, _sync_list_sources)
|
||||
|
||||
|
||||
async def asyncio_create_source(data):
|
||||
return await loop().run_in_executor(None, partial(_sync_create_source, data))
|
||||
|
||||
|
||||
async def asyncio_update_source(source_id, data):
|
||||
return await loop().run_in_executor(None, partial(_sync_update_source, source_id, data))
|
||||
|
||||
|
||||
async def asyncio_get_today_feed():
|
||||
return await loop().run_in_executor(None, _sync_get_today_feed)
|
||||
|
||||
|
||||
async def asyncio_list_feed(source_id=None, status=None):
|
||||
return await loop().run_in_executor(None, partial(_sync_list_feed, source_id, status))
|
||||
|
||||
|
||||
async def asyncio_create_feed_item(data):
|
||||
return await loop().run_in_executor(None, partial(_sync_create_feed_item, data))
|
||||
|
||||
|
||||
async def asyncio_update_feed_item(item_id, data):
|
||||
return await loop().run_in_executor(None, partial(_sync_update_feed_item, item_id, data))
|
||||
183
research_api/routes/research_routes.py
Normal file
183
research_api/routes/research_routes.py
Normal file
@ -0,0 +1,183 @@
|
||||
"""Research API routes using aiohttp."""
|
||||
from aiohttp import web
|
||||
from research_api.routes._db_helpers import (
|
||||
asyncio_get_projects,
|
||||
asyncio_create_project,
|
||||
asyncio_get_project,
|
||||
asyncio_list_intents,
|
||||
asyncio_create_intent,
|
||||
asyncio_update_intent,
|
||||
asyncio_list_papers,
|
||||
asyncio_create_paper,
|
||||
asyncio_get_paper,
|
||||
asyncio_update_paper,
|
||||
asyncio_list_claims,
|
||||
asyncio_create_claim,
|
||||
asyncio_update_claim,
|
||||
asyncio_list_sources,
|
||||
asyncio_create_source,
|
||||
asyncio_update_source,
|
||||
asyncio_get_today_feed,
|
||||
asyncio_list_feed,
|
||||
asyncio_create_feed_item,
|
||||
asyncio_update_feed_item,
|
||||
)
|
||||
|
||||
|
||||
class ResearchRoutes:
|
||||
def __init__(self):
|
||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||
self._app: web.Application = None
|
||||
|
||||
def get_app(self) -> web.Application:
|
||||
if self._app is None:
|
||||
self._app = web.Application()
|
||||
self._app.add_routes(self.routes)
|
||||
return self._app
|
||||
|
||||
def setup_routes(self):
|
||||
# Projects
|
||||
@self.routes.get("/research/projects/")
|
||||
async def list_projects(request):
|
||||
projects = await asyncio_get_projects()
|
||||
return web.json_response(projects)
|
||||
|
||||
@self.routes.post("/research/projects/")
|
||||
async def create_project(request):
|
||||
data = await request.json()
|
||||
project = await asyncio_create_project(data)
|
||||
return web.json_response(project)
|
||||
|
||||
@self.routes.get("/research/projects/{project_id}")
|
||||
async def get_project(request):
|
||||
project_id = request.match_info["project_id"]
|
||||
project = await asyncio_get_project(project_id)
|
||||
if not project:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response(project)
|
||||
|
||||
@self.routes.get("/research/projects/{project_id}/intents")
|
||||
async def list_intents(request):
|
||||
project_id = request.match_info["project_id"]
|
||||
intents = await asyncio_list_intents(project_id)
|
||||
return web.json_response(intents)
|
||||
|
||||
@self.routes.post("/research/projects/{project_id}/intents")
|
||||
async def create_intent(request):
|
||||
project_id = request.match_info["project_id"]
|
||||
data = await request.json()
|
||||
data["project_id"] = project_id
|
||||
intent = await asyncio_create_intent(data)
|
||||
return web.json_response(intent)
|
||||
|
||||
@self.routes.patch("/research/projects/intents/{intent_id}")
|
||||
async def update_intent(request):
|
||||
intent_id = request.match_info["intent_id"]
|
||||
data = await request.json()
|
||||
intent = await asyncio_update_intent(intent_id, data)
|
||||
if not intent:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response(intent)
|
||||
|
||||
# Papers
|
||||
@self.routes.get("/research/papers/")
|
||||
async def list_papers(request):
|
||||
library_status = request.query.get("library_status")
|
||||
read_status = request.query.get("read_status")
|
||||
papers = await asyncio_list_papers(library_status, read_status)
|
||||
return web.json_response(papers)
|
||||
|
||||
@self.routes.post("/research/papers/")
|
||||
async def create_paper(request):
|
||||
data = await request.json()
|
||||
paper = await asyncio_create_paper(data)
|
||||
return web.json_response(paper)
|
||||
|
||||
@self.routes.get("/research/papers/{paper_id}")
|
||||
async def get_paper(request):
|
||||
paper_id = request.match_info["paper_id"]
|
||||
paper = await asyncio_get_paper(paper_id)
|
||||
if not paper:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response(paper)
|
||||
|
||||
@self.routes.patch("/research/papers/{paper_id}")
|
||||
async def update_paper(request):
|
||||
paper_id = request.match_info["paper_id"]
|
||||
data = await request.json()
|
||||
paper = await asyncio_update_paper(paper_id, data)
|
||||
if not paper:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response(paper)
|
||||
|
||||
# Claims
|
||||
@self.routes.get("/research/claims/")
|
||||
async def list_claims(request):
|
||||
project_id = request.query.get("project_id")
|
||||
support_level = request.query.get("support_level")
|
||||
claims = await asyncio_list_claims(project_id, support_level)
|
||||
return web.json_response(claims)
|
||||
|
||||
@self.routes.post("/research/claims/")
|
||||
async def create_claim(request):
|
||||
data = await request.json()
|
||||
claim = await asyncio_create_claim(data)
|
||||
return web.json_response(claim)
|
||||
|
||||
@self.routes.patch("/research/claims/{claim_id}")
|
||||
async def update_claim(request):
|
||||
claim_id = request.match_info["claim_id"]
|
||||
data = await request.json()
|
||||
claim = await asyncio_update_claim(claim_id, data)
|
||||
if not claim:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response(claim)
|
||||
|
||||
# Sources
|
||||
@self.routes.get("/research/sources/")
|
||||
async def list_sources(request):
|
||||
sources = await asyncio_list_sources()
|
||||
return web.json_response(sources)
|
||||
|
||||
@self.routes.post("/research/sources/")
|
||||
async def create_source(request):
|
||||
data = await request.json()
|
||||
source = await asyncio_create_source(data)
|
||||
return web.json_response(source)
|
||||
|
||||
@self.routes.patch("/research/sources/{source_id}")
|
||||
async def update_source(request):
|
||||
source_id = request.match_info["source_id"]
|
||||
data = await request.json()
|
||||
source = await asyncio_update_source(source_id, data)
|
||||
if not source:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response(source)
|
||||
|
||||
# Feed
|
||||
@self.routes.get("/research/feed/today")
|
||||
async def get_today_feed(request):
|
||||
items = await asyncio_get_today_feed()
|
||||
return web.json_response(items)
|
||||
|
||||
@self.routes.get("/research/feed/")
|
||||
async def list_feed(request):
|
||||
source_id = request.query.get("source_id")
|
||||
status = request.query.get("status")
|
||||
items = await asyncio_list_feed(source_id, status)
|
||||
return web.json_response(items)
|
||||
|
||||
@self.routes.post("/research/feed/")
|
||||
async def create_feed_item(request):
|
||||
data = await request.json()
|
||||
item = await asyncio_create_feed_item(data)
|
||||
return web.json_response(item)
|
||||
|
||||
@self.routes.patch("/research/feed/{item_id}")
|
||||
async def update_feed_item(request):
|
||||
item_id = request.match_info["item_id"]
|
||||
data = await request.json()
|
||||
item = await asyncio_update_feed_item(item_id, data)
|
||||
if not item:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response(item)
|
||||
@ -45,6 +45,7 @@ from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from research_api.routes.research_routes import ResearchRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
# Import cache control middleware
|
||||
@ -209,6 +210,8 @@ class PromptServer():
|
||||
self.subgraph_manager = SubgraphManager()
|
||||
self.node_replace_manager = NodeReplaceManager()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.research_routes = ResearchRoutes()
|
||||
self.research_routes.setup_routes()
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = execution.PromptQueue(self)
|
||||
self.loop = loop
|
||||
@ -1048,6 +1051,7 @@ class PromptServer():
|
||||
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.node_replace_manager.add_routes(self.routes)
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
self.app.add_subapp('/research', self.research_routes.get_app())
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
# This is very useful for frontend dev server, which need to forward
|
||||
|
||||
Loading…
Reference in New Issue
Block a user