ComfyUI/tests/sd_clip_model_init_checker.py

70 lines
2.6 KiB
Python

from typing import TYPE_CHECKING, Optional
from pylint.checkers import BaseChecker
if TYPE_CHECKING:
from pylint.lint import PyLinter
from astroid import nodes
class SDClipModelInitChecker(BaseChecker):
    """Pylint checker for ``__init__`` signatures of ``SDClipModel`` subclasses.

    Emits ``sd-clip-model-missing-config`` (W9001) for every class that lists
    ``SDClipModel`` among its direct bases and defines its own ``__init__``
    without an explicit ``textmodel_json_config`` parameter.
    """

    name = 'sd-clip-model-init-checker'
    priority = -1
    msgs = {
        'W9001': (
            'Class %s inheriting from SDClipModel must have textmodel_json_config in __init__ arguments',
            'sd-clip-model-missing-config',
            'Classes inheriting from comfy.sd1_clip.SDClipModel must accept textmodel_json_config in their __init__ method.',
        ),
    }

    def __init__(self, linter: Optional["PyLinter"] = None) -> None:
        """Create the checker, forwarding *linter* to ``BaseChecker``."""
        super().__init__(linter)

    @staticmethod
    def _inherits_sd_clip_model(node: "nodes.ClassDef") -> bool:
        """Return True when a direct base of *node* is spelled ``SDClipModel``.

        Only the syntactic spelling of each base is examined: a bare name
        (``SDClipModel``) is matched through ``name`` and an attribute access
        (``module.SDClipModel``) through ``attrname``. No inference is done,
        so indirect subclasses and aliased imports are not detected.
        """
        return any(
            getattr(base, 'name', '') == 'SDClipModel'
            or getattr(base, 'attrname', '') == 'SDClipModel'
            for base in node.bases
        )

    @staticmethod
    def _parameter_names(arguments) -> set:
        """Collect the names of all explicitly declared parameters.

        Positional-only, regular and keyword-only parameters are included.
        ``*args`` / ``**kwargs`` are deliberately left out: the rule requires
        ``textmodel_json_config`` to be spelled out in the signature.
        """
        names = set()
        for group in ('posonlyargs', 'args', 'kwonlyargs'):
            # A group may be absent or None; treat both as "no parameters".
            for parameter in getattr(arguments, group, None) or ():
                names.add(parameter.name)
        return names

    def visit_classdef(self, node: "nodes.ClassDef") -> None:
        """Report *node* if it subclasses ``SDClipModel`` with a non-compliant ``__init__``.

        Classes that do not define ``__init__`` themselves are skipped: they
        use the parent initializer, which is assumed to be compliant.
        """
        if not self._inherits_sd_clip_model(node):
            return

        init_definitions = node.locals.get('__init__')
        if not init_definitions:
            return

        # ``locals`` lists every binding of the name in the class body. The
        # last binding is the one Python keeps (e.g. the implementation that
        # follows ``@overload`` stubs), so that is the signature to check.
        init_method = init_definitions[-1]
        arguments = getattr(init_method, 'args', None)
        if arguments is None:
            return  # Not a function definition; nothing to inspect.

        if 'textmodel_json_config' not in self._parameter_names(arguments):
            self.add_message('sd-clip-model-missing-config', args=node.name, node=node)
def register(linter: "PyLinter") -> None:
    """Instantiate ``SDClipModelInitChecker`` for *linter* and register it."""
    checker = SDClipModelInitChecker(linter)
    linter.register_checker(checker)