mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
feat: add first 3 research nodes (PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble)
This commit is contained in:
parent
2199e56581
commit
be9241aa0f
17
custom_nodes/research/__init__.py
Normal file
17
custom_nodes/research/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Research Workbench custom nodes for ComfyUI."""
|
||||
from typing import List
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class ResearchExtension(ComfyExtension):
    """Extension bundle exposing the Research Workbench nodes."""

    @override
    async def get_node_list(self) -> List[type[io.ComfyNode]]:
        """Return the node classes contributed by this extension.

        Node-module imports are deferred to call time so the package can
        be imported before the individual node modules are loadable.
        """
        from custom_nodes.research.paper_search import PaperSearch
        from custom_nodes.research.claim_extract import PaperClaimExtract
        from custom_nodes.research.evidence_assemble import ClaimEvidenceAssemble

        node_classes: List[type[io.ComfyNode]] = [
            PaperSearch,
            PaperClaimExtract,
            ClaimEvidenceAssemble,
        ]
        return node_classes
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ComfyExtension:
    """Entry point ComfyUI calls to obtain this package's extension."""
    extension = ResearchExtension()
    return extension
|
||||
68
custom_nodes/research/claim_extract.py
Normal file
68
custom_nodes/research/claim_extract.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""PaperClaimExtract node - extract claims from paper text."""
|
||||
import json
|
||||
import re
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyNode, io
|
||||
|
||||
|
||||
class PaperClaimExtract(io.ComfyNode):
    """Extract scientific claims from paper text or abstract.

    Splits the input into sentences and keeps those containing assertive
    verbs ("achieves", "outperforms", ...), tagging each with a coarse claim
    type.  This is a keyword heuristic, not an NLP model.
    """

    # Verbs that typically introduce a claim sentence.
    _CLAIM_VERBS = (
        "achieves", "outperforms", "improves", "reduces", "increases",
        "demonstrates", "shows", "provides", "enables",
    )
    # Hard cap on the number of claims emitted per call.
    _MAX_CLAIMS = 10

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node's inputs/outputs for the ComfyUI graph editor."""
        return io.Schema(
            node_id="PaperClaimExtract",
            display_name="Extract Claims",
            category="Research",
            inputs=[
                io.String.Input(
                    "paper_text",
                    display_name="Paper Text / Abstract",
                    default="",
                    multiline=True,
                ),
                io.String.Input(
                    "claim_types",
                    display_name="Claim Types (comma-separated)",
                    default="performance,robustness,generalization",
                ),
            ],
            outputs=[
                io.String.Output(display_name="Claims (JSON)"),
            ],
        )

    @classmethod
    def execute(cls, paper_text: str, claim_types: str) -> io.NodeOutput:
        """Return a JSON array of claim dicts extracted from *paper_text*.

        Each claim dict has "text", "type", and "support_level" keys; at
        most ``_MAX_CLAIMS`` claims are emitted.
        """
        if not paper_text.strip():
            # No input text: emit an empty JSON array rather than failing.
            return io.NodeOutput(json.dumps([]))

        # Drop empty entries so "a,,b" or a trailing comma cannot leak an
        # empty-string fallback type into the output.
        claim_type_list = [t.strip() for t in claim_types.split(",") if t.strip()]
        sentences = re.split(r'(?<=[.!?])\s+', paper_text.strip())

        claims = []
        for sent in sentences:
            sent = sent.strip()
            # Skip fragments too short to carry a meaningful claim
            # (also covers the empty string).
            if len(sent) < 20:
                continue
            lower = sent.lower()
            # Heuristic: sentences with comparison/performance words.
            if any(w in lower for w in cls._CLAIM_VERBS):
                claims.append({
                    "text": sent,
                    "type": _classify_claim(sent, claim_type_list),
                    # Evidence is linked by downstream nodes.
                    "support_level": "unsupported",
                })
                if len(claims) >= cls._MAX_CLAIMS:
                    break

        # NodeOutput takes positional values matching define_schema() output
        # order; the original keyword form (claims=...) is not part of the API.
        return io.NodeOutput(json.dumps(claims, indent=2))
|
||||
|
||||
|
||||
def _classify_claim(sentence: str, claim_types) -> str:
|
||||
lower = sentence.lower()
|
||||
if any(w in lower for w in ["accuracy", "performance", "score", "auc", "f1", "precision", "recall"]):
|
||||
return "performance"
|
||||
if any(w in lower for w in ["robust", "noise", "attack", "adversarial", "corruption"]):
|
||||
return "robustness"
|
||||
if any(w in lower for w in ["general", "transfer", "cross-domain", "cross-dataset"]):
|
||||
return "generalization"
|
||||
return claim_types[0] if claim_types else "other"
|
||||
54
custom_nodes/research/evidence_assemble.py
Normal file
54
custom_nodes/research/evidence_assemble.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""ClaimEvidenceAssemble node - assemble claim × evidence matrix."""
|
||||
import json
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyNode, io
|
||||
|
||||
|
||||
class ClaimEvidenceAssemble(io.ComfyNode):
    """Assemble claims and papers into a structured evidence matrix.

    Each claim becomes one matrix row with empty evidence slots; actual
    claim-to-evidence linking is left to downstream nodes.
    """

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node's inputs/outputs for the ComfyUI graph editor."""
        return io.Schema(
            node_id="ClaimEvidenceAssemble",
            display_name="Assemble Evidence",
            category="Research",
            inputs=[
                io.String.Input(
                    "claims",
                    display_name="Claims (JSON)",
                    default="[]",
                    multiline=True,
                ),
                io.String.Input(
                    "papers",
                    display_name="Papers (JSON)",
                    default="[]",
                    multiline=True,
                ),
            ],
            outputs=[
                io.String.Output(display_name="Evidence Matrix (JSON)"),
            ],
        )

    @classmethod
    def execute(cls, claims: str, papers: str) -> io.NodeOutput:
        """Build matrix rows from *claims*; malformed JSON yields an
        ``{"error": ...}`` payload instead of raising."""
        try:
            claims_list = json.loads(claims) if claims else []
            # Parsed for validation only: a malformed papers payload should
            # surface as an error here even though evidence linking (which
            # would consume the papers) happens in a downstream node.
            papers_list = json.loads(papers) if papers else []  # noqa: F841

            matrix = []
            for claim in claims_list:
                # Accept either dict entries ({"text": ..., "type": ...})
                # or bare strings; normalize once per row.
                if isinstance(claim, dict):
                    claim_text = claim.get("text", "")
                    claim_type = claim.get("type", "")
                    support = claim.get("support_level", "unsupported")
                else:
                    claim_text = str(claim)
                    claim_type = ""
                    support = "unsupported"
                matrix.append({
                    "claim": claim_text,
                    "claim_type": claim_type,
                    "support_level": support,
                    "evidence": [],
                    "gap_flags": ["No linked evidence yet"],
                })

            # NodeOutput takes positional values matching define_schema()
            # output order; the original keyword form (evidence_matrix=...)
            # is not part of the API.
            return io.NodeOutput(json.dumps(matrix, indent=2))
        except json.JSONDecodeError as e:
            return io.NodeOutput(json.dumps({"error": f"JSON parse error: {e}"}))
|
||||
62
custom_nodes/research/paper_search.py
Normal file
62
custom_nodes/research/paper_search.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""PaperSearch node - search papers via academic APIs."""
|
||||
import json
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyNode, io
|
||||
|
||||
|
||||
class PaperSearch(io.ComfyNode):
    """Search academic papers from Semantic Scholar.

    Queries the public Semantic Scholar Graph API (no API key required)
    and returns a JSON array of paper metadata.
    """

    # Fields requested from the Semantic Scholar Graph API.
    _FIELDS = "title,authors,abstract,year,journal,venue"

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node's inputs/outputs for the ComfyUI graph editor."""
        return io.Schema(
            node_id="PaperSearch",
            display_name="Paper Search",
            category="Research",
            inputs=[
                io.String.Input(
                    "query",
                    display_name="Search Query",
                    default="",
                ),
                io.Int.Input(
                    "max_results",
                    display_name="Max Results",
                    default=5,
                    min=1,
                    max=20,
                    step=1,
                ),
            ],
            outputs=[
                io.String.Output(display_name="Papers (JSON)"),
            ],
        )

    @classmethod
    def execute(cls, query: str, max_results: int) -> io.NodeOutput:
        """Run the search; network/parse failures become an ``{"error": ...}``
        JSON payload instead of raising."""
        if not query.strip():
            # Empty query: return an empty result set without hitting the API.
            return io.NodeOutput(json.dumps([]))

        # urlencode handles percent-escaping of the query safely.
        params = urllib.parse.urlencode({
            "query": query,
            "limit": max_results,
            "fields": cls._FIELDS,
        })
        url = f"https://api.semanticscholar.org/graph/v1/paper/search?{params}"

        try:
            req = urllib.request.Request(url, headers={"Accept": "application/json"})
            with urllib.request.urlopen(req, timeout=15) as response:
                # The API serves UTF-8 JSON; decode explicitly.
                data = json.loads(response.read().decode("utf-8"))
            papers = []
            for item in data.get("data", []):
                # "x or ''": the API returns JSON null for missing fields,
                # in which case .get(key, "") would yield None, not "".
                papers.append({
                    "title": item.get("title") or "",
                    "authors": [a.get("name") or "" for a in item.get("authors") or []],
                    "abstract": item.get("abstract") or "",
                    "year": item.get("year") or "",
                    "venue": item.get("venue") or "",
                    "paper_id": item.get("paperId") or "",
                })
            # NodeOutput takes positional values matching define_schema()
            # output order; the original keyword form (papers=...) is not
            # part of the API.
            return io.NodeOutput(json.dumps(papers, indent=2))
        except Exception as e:
            # Deliberate best-effort boundary: any network/HTTP/parse failure
            # is reported to the workflow as an error payload.
            return io.NodeOutput(json.dumps({"error": str(e)}))
|
||||
Loading…
Reference in New Issue
Block a user