diff --git a/custom_nodes/research/__init__.py b/custom_nodes/research/__init__.py index 24daf8e20..e97efb426 100644 --- a/custom_nodes/research/__init__.py +++ b/custom_nodes/research/__init__.py @@ -10,7 +10,8 @@ class ResearchExtension(ComfyExtension): from custom_nodes.research.paper_search import PaperSearch from custom_nodes.research.claim_extract import PaperClaimExtract from custom_nodes.research.evidence_assemble import ClaimEvidenceAssemble - return [PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble] + from custom_nodes.research.style_profile import StyleProfileExtract + return [PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble, StyleProfileExtract] async def comfy_entrypoint() -> ComfyExtension: diff --git a/research_api/models.py b/research_api/models.py index 0d139dd3f..33f56c03d 100644 --- a/research_api/models.py +++ b/research_api/models.py @@ -124,3 +124,19 @@ class FeedItem(Base): status = Column(String, default="discovered") # discovered, ranked, presented, quick-reviewed, saved, ignored created_at = Column(DateTime, default=datetime.utcnow) updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + +class StyleAsset(Base): + __tablename__ = "style_assets" + + id = Column(String, primary_key=True, default=new_id) + title = Column(String, nullable=False) + abstract_pattern = Column(String) + intro_pattern = Column(String) + methods_pattern = Column(String) + tone_notes = Column(String) + citation_format = Column(String) + source_paper_refs = Column(String) # JSON string + usage_notes = Column(String) + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) diff --git a/research_api/routes/_db_helpers.py b/research_api/routes/_db_helpers.py index 7c6fa666b..051c93c19 100644 --- a/research_api/routes/_db_helpers.py +++ b/research_api/routes/_db_helpers.py @@ -2,7 +2,7 @@ import asyncio from functools import partial from research_api.db import create_session -from research_api.models import Project, Intent, PaperAsset, ClaimAsset, Source, FeedItem +from research_api.models import Project, Intent, PaperAsset, ClaimAsset, Source, FeedItem, StyleAsset from app.database.models import to_dict @@ -307,3 +307,75 @@ async def asyncio_create_feed_item(data): async def asyncio_update_feed_item(item_id, data): return await loop().run_in_executor(None, partial(_sync_update_feed_item, item_id, data)) + + +# StyleAsset helpers +def _sync_list_styles(): + with create_session() as session: + from sqlalchemy import select + result = session.execute(select(StyleAsset).order_by(StyleAsset.updated_at.desc())) + return [to_dict(s) for s in result.scalars().all()] + + +def _sync_create_style(data): + with create_session() as session: + style = StyleAsset(**data) + session.add(style) + session.commit() + session.refresh(style) + return to_dict(style) + + +def _sync_get_style(style_id): + with create_session() as session: + from sqlalchemy import select + result = session.execute(select(StyleAsset).where(StyleAsset.id == style_id)) + s = result.scalar_one_or_none() + return to_dict(s) if s else None + + +def _sync_update_style(style_id, data): + with create_session() as session: + from sqlalchemy import select + result = session.execute(select(StyleAsset).where(StyleAsset.id == style_id)) + style = result.scalar_one_or_none() + if not style: + return None + for key, value in data.items(): + if hasattr(style, key): + setattr(style, key, value) + session.commit() + session.refresh(style) + return to_dict(style) + + +def _sync_delete_style(style_id): + with create_session() as session: + from sqlalchemy import select + result = session.execute(select(StyleAsset).where(StyleAsset.id == style_id)) + style = result.scalar_one_or_none() + if not style: + return None + session.delete(style) + session.commit() + return True + + +async def asyncio_list_styles(): + return await loop().run_in_executor(None, _sync_list_styles) + + +async def asyncio_create_style(data): + return await loop().run_in_executor(None, partial(_sync_create_style, data)) + + +async def asyncio_get_style(style_id): + return await loop().run_in_executor(None, partial(_sync_get_style, style_id)) + + +async def asyncio_update_style(style_id, data): + return await loop().run_in_executor(None, partial(_sync_update_style, style_id, data)) + + +async def asyncio_delete_style(style_id): + return await loop().run_in_executor(None, partial(_sync_delete_style, style_id)) diff --git a/research_api/routes/research_routes.py b/research_api/routes/research_routes.py index 299a9d54f..1d64c77c3 100644 --- a/research_api/routes/research_routes.py +++ b/research_api/routes/research_routes.py @@ -21,6 +21,11 @@ from research_api.routes._db_helpers import ( asyncio_list_feed, asyncio_create_feed_item, asyncio_update_feed_item, + asyncio_list_styles, + asyncio_create_style, + asyncio_get_style, + asyncio_update_style, + asyncio_delete_style, ) @@ -181,3 +186,40 @@ class ResearchRoutes: if not item: return web.json_response({"error": "Not found"}, status=404) return web.json_response(item) + + # Styles + @self.routes.get("/research/assets/styles/") + async def list_styles(request): + styles = await asyncio_list_styles() + return web.json_response(styles) + + @self.routes.post("/research/assets/styles/") + async def create_style(request): + data = await request.json() + style = await asyncio_create_style(data) + return web.json_response(style, status=201) + + @self.routes.get("/research/assets/styles/{style_id}") + async def get_style(request): + style_id = request.match_info["style_id"] + style = await asyncio_get_style(style_id) + if not style: + return web.json_response({"error": "Not found"}, status=404) + return web.json_response(style) + + @self.routes.patch("/research/assets/styles/{style_id}") + async def update_style(request): + style_id = request.match_info["style_id"] + data = await request.json() + style = await asyncio_update_style(style_id, data) + if not style: + return web.json_response({"error": "Not found"}, status=404) + return web.json_response(style) + + @self.routes.delete("/research/assets/styles/{style_id}") + async def delete_style(request): + style_id = request.match_info["style_id"] + result = await asyncio_delete_style(style_id) + if not result: + return web.json_response({"error": "Not found"}, status=404) + return web.json_response({"status": "deleted"})