mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
feat: add first 3 research nodes (PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble)
This commit is contained in:
parent
2199e56581
commit
be9241aa0f
17
custom_nodes/research/__init__.py
Normal file
17
custom_nodes/research/__init__.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
"""Research Workbench custom nodes for ComfyUI."""
|
||||||
|
from typing import List
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class ResearchExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> List[type[io.ComfyNode]]:
|
||||||
|
from custom_nodes.research.paper_search import PaperSearch
|
||||||
|
from custom_nodes.research.claim_extract import PaperClaimExtract
|
||||||
|
from custom_nodes.research.evidence_assemble import ClaimEvidenceAssemble
|
||||||
|
return [PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ComfyExtension:
|
||||||
|
return ResearchExtension()
|
||||||
68
custom_nodes/research/claim_extract.py
Normal file
68
custom_nodes/research/claim_extract.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
"""PaperClaimExtract node - extract claims from paper text."""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyNode, io
|
||||||
|
|
||||||
|
|
||||||
|
class PaperClaimExtract(io.ComfyNode):
|
||||||
|
"""Extract scientific claims from paper text or abstract."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="PaperClaimExtract",
|
||||||
|
display_name="Extract Claims",
|
||||||
|
category="Research",
|
||||||
|
inputs=[
|
||||||
|
io.String.Input(
|
||||||
|
"paper_text",
|
||||||
|
display_name="Paper Text / Abstract",
|
||||||
|
default="",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
io.String.Input(
|
||||||
|
"claim_types",
|
||||||
|
display_name="Claim Types (comma-separated)",
|
||||||
|
default="performance,robustness,generalization",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.String.Output(display_name="Claims (JSON)"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, paper_text: str, claim_types: str) -> io.NodeOutput:
|
||||||
|
if not paper_text.strip():
|
||||||
|
return io.NodeOutput(claims=json.dumps([]))
|
||||||
|
|
||||||
|
claim_type_list = [t.strip() for t in claim_types.split(",")]
|
||||||
|
sentences = re.split(r'(?<=[.!?])\s+', paper_text.strip())
|
||||||
|
|
||||||
|
claims = []
|
||||||
|
for sent in sentences:
|
||||||
|
sent = sent.strip()
|
||||||
|
if not sent or len(sent) < 20:
|
||||||
|
continue
|
||||||
|
lower = sent.lower()
|
||||||
|
# Heuristic: sentences with comparison/performance words
|
||||||
|
if any(w in lower for w in ["achieves", "outperforms", "improves", "reduces", "increases", "demonstrates", "shows", "provides", "enables"]):
|
||||||
|
claims.append({
|
||||||
|
"text": sent,
|
||||||
|
"type": _classify_claim(sent, claim_type_list),
|
||||||
|
"support_level": "unsupported",
|
||||||
|
})
|
||||||
|
|
||||||
|
return io.NodeOutput(claims=json.dumps(claims[:10], indent=2)) # limit to 10
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_claim(sentence: str, claim_types) -> str:
|
||||||
|
lower = sentence.lower()
|
||||||
|
if any(w in lower for w in ["accuracy", "performance", "score", "auc", "f1", "precision", "recall"]):
|
||||||
|
return "performance"
|
||||||
|
if any(w in lower for w in ["robust", "noise", "attack", "adversarial", "corruption"]):
|
||||||
|
return "robustness"
|
||||||
|
if any(w in lower for w in ["general", "transfer", "cross-domain", "cross-dataset"]):
|
||||||
|
return "generalization"
|
||||||
|
return claim_types[0] if claim_types else "other"
|
||||||
54
custom_nodes/research/evidence_assemble.py
Normal file
54
custom_nodes/research/evidence_assemble.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
"""ClaimEvidenceAssemble node - assemble claim × evidence matrix."""
|
||||||
|
import json
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyNode, io
|
||||||
|
|
||||||
|
|
||||||
|
class ClaimEvidenceAssemble(io.ComfyNode):
|
||||||
|
"""Assemble claims and papers into a structured evidence matrix."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ClaimEvidenceAssemble",
|
||||||
|
display_name="Assemble Evidence",
|
||||||
|
category="Research",
|
||||||
|
inputs=[
|
||||||
|
io.String.Input(
|
||||||
|
"claims",
|
||||||
|
display_name="Claims (JSON)",
|
||||||
|
default="[]",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
io.String.Input(
|
||||||
|
"papers",
|
||||||
|
display_name="Papers (JSON)",
|
||||||
|
default="[]",
|
||||||
|
multiline=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.String.Output(display_name="Evidence Matrix (JSON)"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, claims: str, papers: str) -> io.NodeOutput:
|
||||||
|
try:
|
||||||
|
claims_list = json.loads(claims) if claims else []
|
||||||
|
papers_list = json.loads(papers) if papers else []
|
||||||
|
|
||||||
|
matrix = []
|
||||||
|
for claim in claims_list:
|
||||||
|
claim_text = claim.get("text", "") if isinstance(claim, dict) else str(claim)
|
||||||
|
matrix.append({
|
||||||
|
"claim": claim_text,
|
||||||
|
"claim_type": claim.get("type", "") if isinstance(claim, dict) else "",
|
||||||
|
"support_level": claim.get("support_level", "unsupported") if isinstance(claim, dict) else "unsupported",
|
||||||
|
"evidence": [],
|
||||||
|
"gap_flags": ["No linked evidence yet"],
|
||||||
|
})
|
||||||
|
|
||||||
|
return io.NodeOutput(evidence_matrix=json.dumps(matrix, indent=2))
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return io.NodeOutput(evidence_matrix=json.dumps({"error": f"JSON parse error: {e}"}))
|
||||||
62
custom_nodes/research/paper_search.py
Normal file
62
custom_nodes/research/paper_search.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
"""PaperSearch node - search papers via academic APIs."""
|
||||||
|
import json
|
||||||
|
import urllib.request
|
||||||
|
import urllib.parse
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyNode, io
|
||||||
|
|
||||||
|
|
||||||
|
class PaperSearch(io.ComfyNode):
|
||||||
|
"""Search academic papers from Semantic Scholar."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="PaperSearch",
|
||||||
|
display_name="Paper Search",
|
||||||
|
category="Research",
|
||||||
|
inputs=[
|
||||||
|
io.String.Input(
|
||||||
|
"query",
|
||||||
|
display_name="Search Query",
|
||||||
|
default="",
|
||||||
|
),
|
||||||
|
io.Int.Input(
|
||||||
|
"max_results",
|
||||||
|
display_name="Max Results",
|
||||||
|
default=5,
|
||||||
|
min=1,
|
||||||
|
max=20,
|
||||||
|
step=1,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.String.Output(display_name="Papers (JSON)"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, query: str, max_results: int) -> io.NodeOutput:
|
||||||
|
if not query.strip():
|
||||||
|
return io.NodeOutput(papers=json.dumps([]))
|
||||||
|
|
||||||
|
encoded_query = urllib.parse.quote(query)
|
||||||
|
url = f"https://api.semanticscholar.org/graph/v1/paper/search?query={encoded_query}&limit={max_results}&fields=title,authors,abstract,year,journal,venue"
|
||||||
|
|
||||||
|
try:
|
||||||
|
req = urllib.request.Request(url, headers={"Accept": "application/json"})
|
||||||
|
with urllib.request.urlopen(req, timeout=15) as response:
|
||||||
|
data = json.loads(response.read().decode())
|
||||||
|
papers = []
|
||||||
|
for item in data.get("data", []):
|
||||||
|
papers.append({
|
||||||
|
"title": item.get("title", ""),
|
||||||
|
"authors": [a.get("name", "") for a in item.get("authors", [])],
|
||||||
|
"abstract": item.get("abstract", ""),
|
||||||
|
"year": item.get("year", ""),
|
||||||
|
"venue": item.get("venue", ""),
|
||||||
|
"paper_id": item.get("paperId", ""),
|
||||||
|
})
|
||||||
|
return io.NodeOutput(papers=json.dumps(papers, indent=2))
|
||||||
|
except Exception as e:
|
||||||
|
return io.NodeOutput(papers=json.dumps({"error": str(e)}))
|
||||||
Loading…
Reference in New Issue
Block a user