diff --git a/custom_nodes/research/__init__.py b/custom_nodes/research/__init__.py
new file mode 100644
index 000000000..24daf8e20
--- /dev/null
+++ b/custom_nodes/research/__init__.py
@@ -0,0 +1,17 @@
+"""Research Workbench custom nodes for ComfyUI."""
+from typing import List
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+
+
+class ResearchExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> List[type[io.ComfyNode]]:
+        from custom_nodes.research.paper_search import PaperSearch
+        from custom_nodes.research.claim_extract import PaperClaimExtract
+        from custom_nodes.research.evidence_assemble import ClaimEvidenceAssemble
+        return [PaperSearch, PaperClaimExtract, ClaimEvidenceAssemble]
+
+
+async def comfy_entrypoint() -> ComfyExtension:
+    return ResearchExtension()
diff --git a/custom_nodes/research/claim_extract.py b/custom_nodes/research/claim_extract.py
new file mode 100644
index 000000000..5b56cb3e6
--- /dev/null
+++ b/custom_nodes/research/claim_extract.py
@@ -0,0 +1,69 @@
+"""PaperClaimExtract node - extract claims from paper text."""
+import json
+import re
+
+from comfy_api.latest import io
+
+
+class PaperClaimExtract(io.ComfyNode):
+    """Extract scientific claims from paper text or abstract."""
+
+    @classmethod
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="PaperClaimExtract",
+            display_name="Extract Claims",
+            category="Research",
+            inputs=[
+                io.String.Input(
+                    "paper_text",
+                    display_name="Paper Text / Abstract",
+                    default="",
+                    multiline=True,
+                ),
+                io.String.Input(
+                    "claim_types",
+                    display_name="Claim Types (comma-separated)",
+                    default="performance,robustness,generalization",
+                ),
+            ],
+            outputs=[
+                io.String.Output(display_name="Claims (JSON)"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, paper_text: str, claim_types: str) -> io.NodeOutput:
+        if not paper_text.strip():
+            return io.NodeOutput(json.dumps([]))
+
+        claim_type_list = [t.strip() for t in claim_types.split(",") if t.strip()]
+        sentences = re.split(r'(?<=[.!?])\s+', paper_text.strip())
+
+        claims = []
+        for sent in sentences:
+            sent = sent.strip()
+            # Skip empty fragments and very short sentences, which rarely state claims.
+            if not sent or len(sent) < 20:
+                continue
+            lower = sent.lower()
+            # Heuristic: sentences with comparison/performance words tend to state claims.
+            if any(w in lower for w in ["achieves", "outperforms", "improves", "reduces", "increases", "demonstrates", "shows", "provides", "enables"]):
+                claims.append({
+                    "text": sent,
+                    "type": _classify_claim(sent, claim_type_list),
+                    "support_level": "unsupported",
+                })
+
+        return io.NodeOutput(json.dumps(claims[:10], indent=2))  # cap at 10 claims
+
+
+def _classify_claim(sentence: str, claim_types: list[str]) -> str:
+    lower = sentence.lower()
+    if any(w in lower for w in ["accuracy", "performance", "score", "auc", "f1", "precision", "recall"]):
+        return "performance"
+    if any(w in lower for w in ["robust", "noise", "attack", "adversarial", "corruption"]):
+        return "robustness"
+    if any(w in lower for w in ["general", "transfer", "cross-domain", "cross-dataset"]):
+        return "generalization"
+    return claim_types[0] if claim_types else "other"
diff --git a/custom_nodes/research/evidence_assemble.py b/custom_nodes/research/evidence_assemble.py
new file mode 100644
index 000000000..0cb9cc7ea
--- /dev/null
+++ b/custom_nodes/research/evidence_assemble.py
@@ -0,0 +1,56 @@
+"""ClaimEvidenceAssemble node - assemble claim × evidence matrix."""
+import json
+
+from comfy_api.latest import io
+
+
+class ClaimEvidenceAssemble(io.ComfyNode):
+    """Assemble claims and papers into a structured evidence matrix."""
+
+    @classmethod
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="ClaimEvidenceAssemble",
+            display_name="Assemble Evidence",
+            category="Research",
+            inputs=[
+                io.String.Input(
+                    "claims",
+                    display_name="Claims (JSON)",
+                    default="[]",
+                    multiline=True,
+                ),
+                io.String.Input(
+                    "papers",
+                    display_name="Papers (JSON)",
+                    default="[]",
+                    multiline=True,
+                ),
+            ],
+            outputs=[
+                io.String.Output(display_name="Evidence Matrix (JSON)"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, claims: str, papers: str) -> io.NodeOutput:
+        try:
+            claims_list = json.loads(claims) if claims else []
+            # Papers are parsed to validate their JSON; claim-evidence linking is
+            # not implemented yet, so every row starts with an empty evidence list.
+            papers_list = json.loads(papers) if papers else []
+
+            matrix = []
+            for claim in claims_list:
+                claim_text = claim.get("text", "") if isinstance(claim, dict) else str(claim)
+                matrix.append({
+                    "claim": claim_text,
+                    "claim_type": claim.get("type", "") if isinstance(claim, dict) else "",
+                    "support_level": claim.get("support_level", "unsupported") if isinstance(claim, dict) else "unsupported",
+                    "evidence": [],
+                    "gap_flags": ["No linked evidence yet"],
+                })
+
+            return io.NodeOutput(json.dumps(matrix, indent=2))
+        except json.JSONDecodeError as e:
+            return io.NodeOutput(json.dumps({"error": f"JSON parse error: {e}"}))
diff --git a/custom_nodes/research/paper_search.py b/custom_nodes/research/paper_search.py
new file mode 100644
index 000000000..aa78e7198
--- /dev/null
+++ b/custom_nodes/research/paper_search.py
@@ -0,0 +1,63 @@
+"""PaperSearch node - search papers via academic APIs."""
+import json
+import urllib.parse
+import urllib.request
+
+from comfy_api.latest import io
+
+
+class PaperSearch(io.ComfyNode):
+    """Search academic papers from Semantic Scholar."""
+
+    @classmethod
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="PaperSearch",
+            display_name="Paper Search",
+            category="Research",
+            inputs=[
+                io.String.Input(
+                    "query",
+                    display_name="Search Query",
+                    default="",
+                ),
+                io.Int.Input(
+                    "max_results",
+                    display_name="Max Results",
+                    default=5,
+                    min=1,
+                    max=20,
+                    step=1,
+                ),
+            ],
+            outputs=[
+                io.String.Output(display_name="Papers (JSON)"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, query: str, max_results: int) -> io.NodeOutput:
+        if not query.strip():
+            return io.NodeOutput(json.dumps([]))
+
+        encoded_query = urllib.parse.quote(query)
+        url = f"https://api.semanticscholar.org/graph/v1/paper/search?query={encoded_query}&limit={max_results}&fields=title,authors,abstract,year,venue"
+
+        try:
+            req = urllib.request.Request(url, headers={"Accept": "application/json"})
+            with urllib.request.urlopen(req, timeout=15) as response:
+                data = json.loads(response.read().decode())
+            papers = []
+            for item in data.get("data", []):
+                papers.append({
+                    "title": item.get("title", ""),
+                    "authors": [a.get("name", "") for a in item.get("authors", [])],
+                    "abstract": item.get("abstract", ""),
+                    "year": item.get("year", ""),
+                    "venue": item.get("venue", ""),
+                    "paper_id": item.get("paperId", ""),
+                })
+            return io.NodeOutput(json.dumps(papers, indent=2))
+        except Exception as e:
+            # Network and API errors are reported as a JSON error object instead of raising.
+            return io.NodeOutput(json.dumps({"error": str(e)}))
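For quick manual verification outside the ComfyUI graph, the `execute` classmethods can be driven directly. The sketch below is not part of the diff: the file name `smoke_test.py` is hypothetical, the abstract is a synthetic sample, and it assumes a ComfyUI checkout on `PYTHONPATH` and that `io.NodeOutput` keeps its positional results in an `args` attribute (an assumption about `comfy_api` internals; adjust the accessor if the API differs).

```python
# smoke_test.py -- hypothetical helper, not part of this diff.
# Assumes ComfyUI is importable and that io.NodeOutput stores positional
# results in `.args` (an assumption about comfy_api internals).
import json

from custom_nodes.research.claim_extract import PaperClaimExtract
from custom_nodes.research.evidence_assemble import ClaimEvidenceAssemble

SAMPLE_ABSTRACT = (
    "Our method achieves strong accuracy on the benchmark. "
    "It also improves robustness to label noise and transfers across domains."
)

# Extract claims; both sample sentences pass the keyword heuristic.
claims_json = PaperClaimExtract.execute(
    SAMPLE_ABSTRACT, "performance,robustness,generalization"
).args[0]
print(json.loads(claims_json))  # expect two claims, both "unsupported"

# Assemble the matrix with an empty paper list.
matrix_json = ClaimEvidenceAssemble.execute(claims_json, "[]").args[0]
print(json.loads(matrix_json))  # each row has empty evidence and a gap flag
```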