diff --git a/scanner.py b/scanner.py index 926c992e..1e014580 100644 --- a/scanner.py +++ b/scanner.py @@ -20,7 +20,7 @@ from pathlib import Path from typing import Set, Dict, Optional # Scanner version for cache invalidation -SCANNER_VERSION = "2.0.11" # Multi-layer detection: class existence + display names +SCANNER_VERSION = "2.0.12" # Add dict comprehension + export list detection # Cache for extract_nodes and extract_nodes_enhanced results _extract_nodes_cache: Dict[str, Set[str]] = {} @@ -552,12 +552,22 @@ def extract_nodes_enhanced( if exists: phase5_nodes.add(node_name) - # Union all results (FIX: Scanner 2.0.9 bug + Scanner 2.0.10 bug) + # Phase 6: Dict comprehension pattern (NEW in 2.0.12) + # Detects: NODE_CLASS_MAPPINGS = {cls.__name__: cls for cls in to_export} + # Example: TobiasGlaubach/ComfyUI-TG_PyCode + phase6_nodes = _fallback_dict_comprehension(code_text, file_path) + + # Phase 7: Import-based class names for dict comprehension (NEW in 2.0.12) + # Detects imported classes that are added to export lists + phase7_nodes = _fallback_import_class_names(code_text, file_path) + + # Union all results (FIX: Scanner 2.0.9 bug + Scanner 2.0.10 bug + Scanner 2.0.12 dict comp) # 2.0.9: Used early return which missed Phase 3 nodes # 2.0.10: Only checked registrations, missed classes referenced in display names - all_nodes = phase1_nodes | phase2_nodes | phase3_nodes | phase4_nodes | phase5_nodes + # 2.0.12: Added dict comprehension and import-based class detection + all_nodes = phase1_nodes | phase2_nodes | phase3_nodes | phase4_nodes | phase5_nodes | phase6_nodes | phase7_nodes - # Phase 6: Empty dict detector (logging only, doesn't add nodes) + # Phase 8: Empty dict detector (logging only, doesn't add nodes) if not all_nodes: _fallback_empty_dict_detector(code_text, file_path, verbose) @@ -616,7 +626,7 @@ def _fallback_classname_resolver(code_text: str, file_path: Optional[Path]) -> S def _fallback_item_assignment(code_text: str) -> Set[str]: """ Detect item assignment pattern. - + Pattern: NODE_CLASS_MAPPINGS = {} NODE_CLASS_MAPPINGS["MyNode"] = MyNode @@ -627,9 +637,9 @@ def _fallback_item_assignment(code_text: str) -> Set[str]: parsed = ast.parse(code_text) except: return set() - + nodes = set() - + for node in ast.walk(parsed): if isinstance(node, ast.Assign): for target in node.targets: @@ -640,10 +650,156 @@ def _fallback_item_assignment(code_text: str) -> Set[str]: if isinstance(target.slice, ast.Constant): if isinstance(target.slice.value, str): nodes.add(target.slice.value) - + return nodes +def _fallback_dict_comprehension(code_text: str, file_path: Optional[Path] = None) -> Set[str]: + """ + Detect dict comprehension pattern with __name__ attribute access. + + Pattern: + NODE_CLASS_MAPPINGS = {cls.__name__: cls for cls in to_export} + NODE_CLASS_MAPPINGS = {c.__name__: c for c in [ClassA, ClassB]} + + This function detects dict comprehension assignments to NODE_CLASS_MAPPINGS + and extracts class names from the iterable (list literal or variable reference). + + Returns: + Set of class names extracted from the dict comprehension + """ + try: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=SyntaxWarning) + parsed = ast.parse(code_text) + except: + return set() + + nodes = set() + export_lists = {} # Track list variables and their contents + + # First pass: collect list assignments (to_export = [...], exports = [...]) + for node in ast.walk(parsed): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name): + var_name = target.id + # Check for list literal + if isinstance(node.value, ast.List): + class_names = set() + for elt in node.value.elts: + if isinstance(elt, ast.Name): + class_names.add(elt.id) + export_lists[var_name] = class_names + + # Handle augmented assignment: to_export += [...] + elif isinstance(node, ast.AugAssign): + if isinstance(node.target, ast.Name) and isinstance(node.op, ast.Add): + var_name = node.target.id + if isinstance(node.value, ast.List): + class_names = set() + for elt in node.value.elts: + if isinstance(elt, ast.Name): + class_names.add(elt.id) + if var_name in export_lists: + export_lists[var_name].update(class_names) + else: + export_lists[var_name] = class_names + + # Second pass: find NODE_CLASS_MAPPINGS dict comprehension + for node in ast.walk(parsed): + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id in ['NODE_CLASS_MAPPINGS', 'NODE_CONFIG']: + # Check for dict comprehension + if isinstance(node.value, ast.DictComp): + dictcomp = node.value + + # Check if key is cls.__name__ pattern + key = dictcomp.key + if isinstance(key, ast.Attribute) and key.attr == '__name__': + # Get the iterable from the first generator + for generator in dictcomp.generators: + iter_node = generator.iter + + # Case 1: Inline list [ClassA, ClassB, ...] + if isinstance(iter_node, ast.List): + for elt in iter_node.elts: + if isinstance(elt, ast.Name): + nodes.add(elt.id) + + # Case 2: Variable reference (to_export, exports, etc.) + elif isinstance(iter_node, ast.Name): + var_name = iter_node.id + if var_name in export_lists: + nodes.update(export_lists[var_name]) + + return nodes + + +def _fallback_import_class_names(code_text: str, file_path: Optional[Path] = None) -> Set[str]: + """ + Extract class names from imports that are added to export lists. + + Pattern: + from .module import ClassA, ClassB + to_export = [ClassA, ClassB] + NODE_CLASS_MAPPINGS = {cls.__name__: cls for cls in to_export} + + This is a complementary fallback that works with _fallback_dict_comprehension + to resolve import-based node registrations. + + Returns: + Set of imported class names that appear in export-like contexts + """ + try: + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=SyntaxWarning) + parsed = ast.parse(code_text) + except: + return set() + + # Collect imported names + imported_names = set() + for node in ast.walk(parsed): + if isinstance(node, ast.ImportFrom): + for alias in node.names: + name = alias.asname if alias.asname else alias.name + imported_names.add(name) + + # Check if these names appear in list assignments that feed into NODE_CLASS_MAPPINGS + export_candidates = set() + has_dict_comp_mapping = False + + for node in ast.walk(parsed): + # Check for dict comprehension NODE_CLASS_MAPPINGS + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == 'NODE_CLASS_MAPPINGS': + if isinstance(node.value, ast.DictComp): + has_dict_comp_mapping = True + + # Collect list contents + if isinstance(node, ast.Assign): + if isinstance(node.value, ast.List): + for elt in node.value.elts: + if isinstance(elt, ast.Name) and elt.id in imported_names: + export_candidates.add(elt.id) + + # Handle augmented assignment + elif isinstance(node, ast.AugAssign): + if isinstance(node.value, ast.List): + for elt in node.value.elts: + if isinstance(elt, ast.Name) and elt.id in imported_names: + export_candidates.add(elt.id) + + # Only return if there's a dict comprehension mapping + if has_dict_comp_mapping: + return export_candidates + + return set() + + def _extract_repo_name(file_path: Path) -> str: """ Extract repository name from file path.