diff --git a/scanner.py b/scanner.py index dcd9e767..69539af9 100644 --- a/scanner.py +++ b/scanner.py @@ -105,10 +105,17 @@ def extract_nodes(code_text): warnings.filterwarnings('ignore', category=DeprecationWarning) parsed_code = ast.parse(code_text) - assignments = (node for node in parsed_code.body if isinstance(node, ast.Assign)) + # Support both ast.Assign and ast.AnnAssign (for type-annotated assignments) + assignments = (node for node in parsed_code.body if isinstance(node, (ast.Assign, ast.AnnAssign))) for assignment in assignments: - if isinstance(assignment.targets[0], ast.Name) and assignment.targets[0].id in ['NODE_CONFIG', 'NODE_CLASS_MAPPINGS']: + # Handle ast.AnnAssign (e.g., NODE_CLASS_MAPPINGS: Type = {...}) + if isinstance(assignment, ast.AnnAssign): + if isinstance(assignment.target, ast.Name) and assignment.target.id in ['NODE_CONFIG', 'NODE_CLASS_MAPPINGS']: + node_class_mappings = assignment.value + break + # Handle ast.Assign (e.g., NODE_CLASS_MAPPINGS = {...}) + elif isinstance(assignment.targets[0], ast.Name) and assignment.targets[0].id in ['NODE_CONFIG', 'NODE_CLASS_MAPPINGS']: node_class_mappings = assignment.value break else: @@ -228,7 +235,8 @@ def scan_in_file(filename, is_builtin=False): with open(filename, encoding='utf-8', errors='ignore') as file: code = file.read() - pattern = r"_CLASS_MAPPINGS\s*=\s*{([^}]*)}" + # Support type annotations (e.g., NODE_CLASS_MAPPINGS: Type = {...}) and line continuations (\) + pattern = r"_CLASS_MAPPINGS\s*(?::\s*\w+\s*)?=\s*(?:\\\s*)?{([^}]*)}" regex = re.compile(pattern, re.MULTILINE | re.DOTALL) nodes = set()