mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
feat: add StyleAsset model and StyleProfileExtract node
- Add StyleAsset model to research_api/models.py for storing writing style profiles - Add StyleAsset db helpers (sync + async) to research_api/routes/_db_helpers.py - Add StyleAsset REST routes to research_api/routes/research_routes.py
This commit is contained in:
parent
efecc78f54
commit
24d79c08e6
@ -10,7 +10,8 @@ class ResearchExtension(ComfyExtension):
|
||||
from custom_nodes.research.paper_search import PaperSearch
|
||||
from custom_nodes.research.claim_extract import PaperClaimExtract
|
||||
from custom_nodes.research.evidence_assemble import ClaimEvidenceAssemble
|
||||
return [PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble]
|
||||
from custom_nodes.research.style_profile import StyleProfileExtract
|
||||
return [PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble, StyleProfileExtract]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ComfyExtension:
|
||||
|
||||
@ -124,3 +124,19 @@ class FeedItem(Base):
|
||||
status = Column(String, default="discovered") # discovered, ranked, presented, quick-reviewed, saved, ignored
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
|
||||
class StyleAsset(Base):
|
||||
__tablename__ = "style_assets"
|
||||
|
||||
id = Column(String, primary_key=True, default=new_id)
|
||||
title = Column(String, nullable=False)
|
||||
abstract_pattern = Column(String)
|
||||
intro_pattern = Column(String)
|
||||
methods_pattern = Column(String)
|
||||
tone_notes = Column(String)
|
||||
citation_format = Column(String)
|
||||
source_paper_refs = Column(String) # JSON string
|
||||
usage_notes = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
import asyncio
|
||||
from functools import partial
|
||||
from research_api.db import create_session
|
||||
from research_api.models import Project, Intent, PaperAsset, ClaimAsset, Source, FeedItem
|
||||
from research_api.models import Project, Intent, PaperAsset, ClaimAsset, Source, FeedItem, StyleAsset
|
||||
from app.database.models import to_dict
|
||||
|
||||
|
||||
@ -307,3 +307,75 @@ async def asyncio_create_feed_item(data):
|
||||
|
||||
async def asyncio_update_feed_item(item_id, data):
|
||||
return await loop().run_in_executor(None, partial(_sync_update_feed_item, item_id, data))
|
||||
|
||||
|
||||
# StyleAsset helpers
|
||||
def _sync_list_styles():
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(StyleAsset).order_by(StyleAsset.updated_at.desc()))
|
||||
return [to_dict(s) for s in result.scalars().all()]
|
||||
|
||||
|
||||
def _sync_create_style(data):
|
||||
with create_session() as session:
|
||||
style = StyleAsset(**data)
|
||||
session.add(style)
|
||||
session.commit()
|
||||
session.refresh(style)
|
||||
return to_dict(style)
|
||||
|
||||
|
||||
def _sync_get_style(style_id):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(StyleAsset).where(StyleAsset.id == style_id))
|
||||
s = result.scalar_one_or_none()
|
||||
return to_dict(s) if s else None
|
||||
|
||||
|
||||
def _sync_update_style(style_id, data):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(StyleAsset).where(StyleAsset.id == style_id))
|
||||
style = result.scalar_one_or_none()
|
||||
if not style:
|
||||
return None
|
||||
for key, value in data.items():
|
||||
if hasattr(style, key):
|
||||
setattr(style, key, value)
|
||||
session.commit()
|
||||
session.refresh(style)
|
||||
return to_dict(style)
|
||||
|
||||
|
||||
def _sync_delete_style(style_id):
|
||||
with create_session() as session:
|
||||
from sqlalchemy import select
|
||||
result = session.execute(select(StyleAsset).where(StyleAsset.id == style_id))
|
||||
style = result.scalar_one_or_none()
|
||||
if not style:
|
||||
return None
|
||||
session.delete(style)
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def asyncio_list_styles():
|
||||
return await loop().run_in_executor(None, _sync_list_styles)
|
||||
|
||||
|
||||
async def asyncio_create_style(data):
|
||||
return await loop().run_in_executor(None, partial(_sync_create_style, data))
|
||||
|
||||
|
||||
async def asyncio_get_style(style_id):
|
||||
return await loop().run_in_executor(None, partial(_sync_get_style, style_id))
|
||||
|
||||
|
||||
async def asyncio_update_style(style_id, data):
|
||||
return await loop().run_in_executor(None, partial(_sync_update_style, style_id, data))
|
||||
|
||||
|
||||
async def asyncio_delete_style(style_id):
|
||||
return await loop().run_in_executor(None, partial(_sync_delete_style, style_id))
|
||||
|
||||
@ -21,6 +21,11 @@ from research_api.routes._db_helpers import (
|
||||
asyncio_list_feed,
|
||||
asyncio_create_feed_item,
|
||||
asyncio_update_feed_item,
|
||||
asyncio_list_styles,
|
||||
asyncio_create_style,
|
||||
asyncio_get_style,
|
||||
asyncio_update_style,
|
||||
asyncio_delete_style,
|
||||
)
|
||||
|
||||
|
||||
@ -181,3 +186,40 @@ class ResearchRoutes:
|
||||
if not item:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response(item)
|
||||
|
||||
# Styles
|
||||
@self.routes.get("/research/assets/styles/")
|
||||
async def list_styles(request):
|
||||
styles = await asyncio_list_styles()
|
||||
return web.json_response(styles)
|
||||
|
||||
@self.routes.post("/research/assets/styles/")
|
||||
async def create_style(request):
|
||||
data = await request.json()
|
||||
style = await asyncio_create_style(data)
|
||||
return web.json_response(style, status=201)
|
||||
|
||||
@self.routes.get("/research/assets/styles/{style_id}")
|
||||
async def get_style(request):
|
||||
style_id = request.match_info["style_id"]
|
||||
style = await asyncio_get_style(style_id)
|
||||
if not style:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response(style)
|
||||
|
||||
@self.routes.patch("/research/assets/styles/{style_id}")
|
||||
async def update_style(request):
|
||||
style_id = request.match_info["style_id"]
|
||||
data = await request.json()
|
||||
style = await asyncio_update_style(style_id, data)
|
||||
if not style:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response(style)
|
||||
|
||||
@self.routes.delete("/research/assets/styles/{style_id}")
|
||||
async def delete_style(request):
|
||||
style_id = request.match_info["style_id"]
|
||||
result = await asyncio_delete_style(style_id)
|
||||
if not result:
|
||||
return web.json_response({"error": "Not found"}, status=404)
|
||||
return web.json_response({"status": "deleted"})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user