From 6e94c13035a16c3cba07aced8d720313c540099c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=AF=BA=E6=96=AF=E8=B4=B9=E6=8B=89=E5=9B=BE?= <1132505822@qq.com> Date: Sun, 12 Apr 2026 17:37:24 +0800 Subject: [PATCH] feat: add ReferencePaperSelect node --- custom_nodes/research/__init__.py | 3 +- .../research/reference_paper_select.py | 59 +++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 custom_nodes/research/reference_paper_select.py diff --git a/custom_nodes/research/__init__.py b/custom_nodes/research/__init__.py index e97efb426..15300b2a3 100644 --- a/custom_nodes/research/__init__.py +++ b/custom_nodes/research/__init__.py @@ -11,7 +11,8 @@ class ResearchExtension(ComfyExtension): from custom_nodes.research.claim_extract import PaperClaimExtract from custom_nodes.research.evidence_assemble import ClaimEvidenceAssemble from custom_nodes.research.style_profile import StyleProfileExtract - return [PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble, StyleProfileExtract] + from custom_nodes.research.reference_paper_select import ReferencePaperSelect + return [PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble, StyleProfileExtract, ReferencePaperSelect] async def comfy_entrypoint() -> ComfyExtension: diff --git a/custom_nodes/research/reference_paper_select.py b/custom_nodes/research/reference_paper_select.py new file mode 100644 index 000000000..08f681302 --- /dev/null +++ b/custom_nodes/research/reference_paper_select.py @@ -0,0 +1,59 @@ +"""ReferencePaperSelect node - select papers from project for canvas use.""" +import json +from typing_extensions import override + +from comfy_api.latest import ComfyNode, io + + +class ReferencePaperSelect(io.ComfyNode): + """Select papers from a project to load onto the canvas for reference.""" + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ReferencePaperSelect", + display_name="Select Papers", + category="Research", + inputs=[ + io.String.Input( + "project_id", + display_name="Project ID", + default="", + ), + io.Int.Input( + "max_papers", + display_name="Max Papers", + default=5, + min=1, + max=20, + step=1, + ), + ], + outputs=[ + io.String.Output(display_name="Selected Papers (JSON)"), + ], + ) + + @classmethod + def execute(cls, project_id: str, max_papers: int) -> io.NodeOutput: + if not project_id.strip(): + return io.NodeOutput(selected_papers=json.dumps([])) + + # Phase 1: Returns mock data + # Phase 2+ would query the actual project papers from DB + return io.NodeOutput(selected_papers=json.dumps([ + { + "paper_id": "mock-id-1", + "title": "Sample Paper 1", + "authors": ["Author A", "Author B"], + "abstract": "This is a sample abstract for demonstration.", + "year": "2024", + }, + { + "paper_id": "mock-id-2", + "title": "Sample Paper 2", + "authors": ["Author C"], + "abstract": "Another sample abstract.", + "year": "2023", + } + ][:max_papers], indent=2))