feat: add first 3 research nodes (PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble)

This commit is contained in:
诺斯费拉图 2026-04-12 17:12:19 +08:00
parent 2199e56581
commit be9241aa0f
4 changed files with 201 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
"""Research Workbench custom nodes for ComfyUI."""
from typing import List
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class ResearchExtension(ComfyExtension):
    """Extension bundling the Research Workbench custom nodes."""

    @override
    async def get_node_list(self) -> List[type[io.ComfyNode]]:
        """Return the research node classes.

        Imports are deferred to call time so the node modules are only
        loaded when ComfyUI actually asks for the node list.
        """
        from custom_nodes.research.claim_extract import PaperClaimExtract
        from custom_nodes.research.evidence_assemble import ClaimEvidenceAssemble
        from custom_nodes.research.paper_search import PaperSearch

        return [PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble]
async def comfy_entrypoint() -> ComfyExtension:
    """Entry point called by ComfyUI to obtain this package's extension."""
    extension = ResearchExtension()
    return extension

View File

@@ -0,0 +1,68 @@
"""PaperClaimExtract node - extract claims from paper text."""
import json
import re
from typing_extensions import override
from comfy_api.latest import ComfyNode, io
class PaperClaimExtract(io.ComfyNode):
    """Extract scientific claims from paper text or abstract.

    Splits the input into sentences, keeps those that contain an
    assertive/comparative verb, and tags each with a coarse claim type
    via the module-level ``_classify_claim`` heuristic.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: paper text plus a claim-type list, JSON output."""
        return io.Schema(
            node_id="PaperClaimExtract",
            display_name="Extract Claims",
            category="Research",
            inputs=[
                io.String.Input(
                    "paper_text",
                    display_name="Paper Text / Abstract",
                    default="",
                    multiline=True,
                ),
                io.String.Input(
                    "claim_types",
                    display_name="Claim Types (comma-separated)",
                    default="performance,robustness,generalization",
                ),
            ],
            outputs=[
                io.String.Output(display_name="Claims (JSON)"),
            ],
        )

    @classmethod
    def execute(cls, paper_text: str, claim_types: str) -> io.NodeOutput:
        """Return up to 10 heuristically extracted claims as a JSON array string.

        Args:
            paper_text: Free-form paper text or abstract; blank input yields "[]".
            claim_types: Comma-separated fallback type labels for unmatched claims.
        """
        if not paper_text.strip():
            return io.NodeOutput(claims=json.dumps([]))
        # Drop empty entries so "" input or a trailing comma cannot produce an
        # empty-string claim type via the _classify_claim fallback.
        claim_type_list = [t.strip() for t in claim_types.split(",") if t.strip()]
        sentences = re.split(r'(?<=[.!?])\s+', paper_text.strip())
        claims = []
        for sent in sentences:
            sent = sent.strip()
            # Skip fragments; the length test also rejects empty strings.
            if len(sent) < 20:
                continue
            lower = sent.lower()
            # Heuristic: sentences with comparison/performance words
            if any(w in lower for w in ["achieves", "outperforms", "improves", "reduces", "increases", "demonstrates", "shows", "provides", "enables"]):
                claims.append({
                    "text": sent,
                    "type": _classify_claim(sent, claim_type_list),
                    # Evidence linking happens in a downstream node.
                    "support_level": "unsupported",
                })
        return io.NodeOutput(claims=json.dumps(claims[:10], indent=2))  # limit to 10
def _classify_claim(sentence: str, claim_types) -> str:
lower = sentence.lower()
if any(w in lower for w in ["accuracy", "performance", "score", "auc", "f1", "precision", "recall"]):
return "performance"
if any(w in lower for w in ["robust", "noise", "attack", "adversarial", "corruption"]):
return "robustness"
if any(w in lower for w in ["general", "transfer", "cross-domain", "cross-dataset"]):
return "generalization"
return claim_types[0] if claim_types else "other"

View File

@@ -0,0 +1,54 @@
"""ClaimEvidenceAssemble node - assemble claim × evidence matrix."""
import json
from typing_extensions import override
from comfy_api.latest import ComfyNode, io
class ClaimEvidenceAssemble(io.ComfyNode):
    """Assemble claims and papers into a structured evidence matrix."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: two JSON string inputs, one JSON string output."""
        claims_input = io.String.Input(
            "claims",
            display_name="Claims (JSON)",
            default="[]",
            multiline=True,
        )
        papers_input = io.String.Input(
            "papers",
            display_name="Papers (JSON)",
            default="[]",
            multiline=True,
        )
        return io.Schema(
            node_id="ClaimEvidenceAssemble",
            display_name="Assemble Evidence",
            category="Research",
            inputs=[claims_input, papers_input],
            outputs=[io.String.Output(display_name="Evidence Matrix (JSON)")],
        )

    @classmethod
    def execute(cls, claims: str, papers: str) -> io.NodeOutput:
        """Build a claim-by-evidence matrix (JSON string) from the inputs.

        Each row carries claim text/type/support level plus empty evidence
        slots; a JSON parse failure yields an {"error": ...} object instead
        of raising.
        """
        try:
            parsed_claims = json.loads(claims) if claims else []
            # NOTE(review): papers is parsed but never linked into the rows —
            # presumably evidence linking is a later iteration; confirm intent.
            parsed_papers = json.loads(papers) if papers else []
            rows = []
            for entry in parsed_claims:
                if isinstance(entry, dict):
                    row = {
                        "claim": entry.get("text", ""),
                        "claim_type": entry.get("type", ""),
                        "support_level": entry.get("support_level", "unsupported"),
                    }
                else:
                    # Bare (non-dict) claims are kept as plain text.
                    row = {
                        "claim": str(entry),
                        "claim_type": "",
                        "support_level": "unsupported",
                    }
                row["evidence"] = []
                row["gap_flags"] = ["No linked evidence yet"]
                rows.append(row)
            return io.NodeOutput(evidence_matrix=json.dumps(rows, indent=2))
        except json.JSONDecodeError as e:
            return io.NodeOutput(evidence_matrix=json.dumps({"error": f"JSON parse error: {e}"}))

View File

@@ -0,0 +1,62 @@
"""PaperSearch node - search papers via academic APIs."""
import json
import urllib.request
import urllib.parse
from typing_extensions import override
from comfy_api.latest import ComfyNode, io
class PaperSearch(io.ComfyNode):
    """Search academic papers from Semantic Scholar."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: a query string plus a bounded result count."""
        query_input = io.String.Input(
            "query",
            display_name="Search Query",
            default="",
        )
        count_input = io.Int.Input(
            "max_results",
            display_name="Max Results",
            default=5,
            min=1,
            max=20,
            step=1,
        )
        return io.Schema(
            node_id="PaperSearch",
            display_name="Paper Search",
            category="Research",
            inputs=[query_input, count_input],
            outputs=[io.String.Output(display_name="Papers (JSON)")],
        )

    @classmethod
    def execute(cls, query: str, max_results: int) -> io.NodeOutput:
        """Query the Semantic Scholar Graph API and return papers as JSON text.

        A blank query short-circuits to "[]"; any network or parse failure
        is reported as {"error": ...} in the output rather than raised.
        """
        if not query.strip():
            return io.NodeOutput(papers=json.dumps([]))
        encoded_query = urllib.parse.quote(query)
        fields = "title,authors,abstract,year,journal,venue"
        url = (
            "https://api.semanticscholar.org/graph/v1/paper/search"
            f"?query={encoded_query}&limit={max_results}&fields={fields}"
        )
        try:
            request = urllib.request.Request(url, headers={"Accept": "application/json"})
            with urllib.request.urlopen(request, timeout=15) as response:
                payload = json.loads(response.read().decode())
            results = [
                {
                    "title": item.get("title", ""),
                    "authors": [author.get("name", "") for author in item.get("authors", [])],
                    "abstract": item.get("abstract", ""),
                    "year": item.get("year", ""),
                    "venue": item.get("venue", ""),
                    "paper_id": item.get("paperId", ""),
                }
                for item in payload.get("data", [])
            ]
            return io.NodeOutput(papers=json.dumps(results, indent=2))
        except Exception as e:
            # Best-effort boundary: surface the failure in the node output.
            return io.NodeOutput(papers=json.dumps({"error": str(e)}))