Source code for stubalyzer.compare

"""
Compare mypy types.
"""

from __future__ import annotations

from enum import Enum
from typing import Any, Dict, NamedTuple, Optional, Union

from mypy.nodes import (
    ARG_NAMED,
    ARG_POS,
    CONTRAVARIANT,
    COVARIANT,
    Decorator,
    FuncDef,
    SymbolNode,
    TypeAlias,
    TypeInfo,
    TypeVarExpr,
)
from mypy.subtypes import is_subtype
from mypy.types import CallableType, Overloaded
from mypy.types import Type as TypeNode

from .types import RelevantSymbolNode
from .utils import get_expression_fullname


class MatchResult(Enum):
    """Possible outcomes of comparing a symbol with its reference."""

    MATCH = "match"
    MISMATCH = "mismatch"
    NOT_FOUND = "not_found"
    MISLOCATED_SYMBOL = "mislocated_symbol"

    @classmethod
    def declare_mismatch(cls, matchResultString: str) -> MatchResult:
        """
        Convert a string into the :class:`MatchResult` member denoting a mismatch.

        Every member's value except the one of :attr:`MATCH` is accepted.

        :param matchResultString: value of the wanted member, e.g. ``"mismatch"``
            or ``"not_found"``
        :raises ValueError: if the string is not the value of any member, or is
            the value of :attr:`MATCH` (which is not a mismatch)
        :return: the member whose value is ``matchResultString``
        """
        try:
            result: Optional[MatchResult] = cls(matchResultString)
        except ValueError:
            # Unknown value; reported together with the MATCH case below.
            result = None

        if result is None or result is cls.MATCH:
            possible_values = ", ".join(
                f'"{member.value}"' for member in cls if member is not cls.MATCH
            )
            raise ValueError(
                f'"{matchResultString}" is not a valid mismatch type.'
                f" (Use one of {possible_values})"
            )
        return result
def _get_symbol_type_info(symbol: SymbolNode) -> str:
    """
    Describe the type of ``symbol`` as a human readable string.

    - type aliases are shown as the ``repr`` of their target,
    - type variables are rendered like their ``TypeVar(...)`` declaration,
    - classes are shown as ``Class(<full name>)``,
    - anything else is shown as the ``repr`` of its ``type`` attribute
      (``None`` if the symbol has no such attribute).

    :param symbol: symbol whose type should be described
    """
    if isinstance(symbol, TypeAlias):
        description = repr(symbol.target)
    elif isinstance(symbol, TypeVarExpr):
        description = _format_type_var(symbol)
    elif isinstance(symbol, TypeInfo):
        description = f"Class({symbol.fullname})"
    else:
        description = repr(getattr(symbol, "type", None))
    return description
class ComparisonResult(NamedTuple):
    """
    Outcome of comparing a symbol node (and its type) with a reference symbol node.
    """

    #: Kind of outcome of the comparison
    match_result: MatchResult
    #: The symbol that was checked
    symbol: RelevantSymbolNode
    #: The reference symbol it was checked against, if there is one
    reference: Optional[SymbolNode]
    #: Full name of the checked symbol
    symbol_name: str
    #: Human readable type of the checked symbol
    symbol_type: str
    #: Full name of the reference symbol
    reference_name: Optional[str]
    #: Human readable type of the reference symbol
    reference_type: Optional[str]
    #: Optional additional data
    data: Optional[Dict[str, Any]] = None
    #: Optional message; when non-empty it replaces the generated one
    message_val: Optional[str] = None

    @property
    def message(self) -> str:
        """Human readable description of the comparison outcome"""
        if self.message_val:
            return self.message_val

        outcome = self.match_result
        if outcome is MatchResult.MATCH:
            lines = [
                f"Types for {self.symbol_name} match:",
                f" {self.symbol_type}",
                f" {self.reference_type}",
            ]
            return "\n".join(lines)
        if outcome is MatchResult.MISMATCH:
            lines = [
                f"Types for {self.symbol_name} do not match:",
                f" Handwritten type: {self.symbol_type}",
                f" Reference type : {self.reference_type}",
            ]
            return "\n".join(lines)
        if outcome is MatchResult.NOT_FOUND:
            return f'Symbol "{self.symbol_name}" not found in generated stubs'
        if outcome is MatchResult.MISLOCATED_SYMBOL:
            return (
                f'Found symbol "{self.symbol_name}" in different location'
                f' "{self.reference_name}".'
            )

    @classmethod
    def create(
        cls,
        match_result: MatchResult,
        symbol: RelevantSymbolNode,
        reference: Optional[SymbolNode],
        data: Optional[Dict[str, Any]] = None,
        message: Optional[str] = None,
    ) -> ComparisonResult:
        """
        Build a comparison result, deriving names and type descriptions
        from the given symbols.

        :param match_result: outcome of the comparison
        :param symbol: symbol that was checked
        :param reference: reference symbol that was checked against (or ``None``)
        :param data: optional additional data
        :param message: optional message overriding the generated one
        """
        symbol_name = symbol.fullname
        symbol_type = _get_symbol_type_info(symbol)
        reference_name = reference.fullname if reference else None
        reference_type = _get_symbol_type_info(reference) if reference else None
        return cls(
            match_result=match_result,
            symbol=symbol,
            reference=reference,
            symbol_name=symbol_name,
            symbol_type=symbol_type,
            reference_name=reference_name,
            reference_type=reference_type,
            data=data,
            message_val=message,
        )

    @classmethod
    def create_not_found(
        cls, symbol: RelevantSymbolNode, data: Optional[Dict[str, Any]] = None
    ) -> ComparisonResult:
        """
        Build a failed comparison result for a symbol without any reference.

        :param symbol: symbol we wanted to check
        :param data: optional additional data
        """
        return cls.create(
            MatchResult.NOT_FOUND, symbol=symbol, reference=None, data=data
        )

    @classmethod
    def create_mislocated_symbol(
        cls,
        symbol: RelevantSymbolNode,
        reference: SymbolNode,
        data: Optional[Dict[str, Any]] = None,
    ) -> ComparisonResult:
        """
        Build a failed comparison result for a symbol whose reference was
        found at a different level of the class hierarchy.

        :param symbol: symbol we wanted to check
        :param reference: symbol found elsewhere in the hierarchy
        :param data: optional additional data
        """
        return cls.create(
            MatchResult.MISLOCATED_SYMBOL, symbol=symbol, reference=reference, data=data
        )

    @classmethod
    def create_mismatch(
        cls,
        symbol: RelevantSymbolNode,
        reference: RelevantSymbolNode,
        data: Optional[Dict[str, Any]] = None,
        message: Optional[str] = None,
    ) -> ComparisonResult:
        """
        Build a failed comparison result.

        :param symbol: symbol that was checked
        :param reference: reference symbol that was checked against
        :param data: optional additional data
        :param message: optional message overriding the generated one
        """
        return cls.create(
            MatchResult.MISMATCH,
            symbol=symbol,
            reference=reference,
            data=data,
            message=message,
        )

    @classmethod
    def create_match(
        cls,
        symbol: RelevantSymbolNode,
        reference: RelevantSymbolNode,
        data: Optional[Dict[str, Any]] = None,
        message: Optional[str] = None,
    ) -> ComparisonResult:
        """
        Build a successful comparison result.

        :param symbol: symbol that was checked
        :param reference: reference symbol that was checked against
        :param data: optional additional data
        :param message: optional message overriding the generated one
        """
        return cls.create(
            MatchResult.MATCH,
            symbol=symbol,
            reference=reference,
            data=data,
            message=message,
        )
def _mypy_types_match(symbol_type: TypeNode, reference_type: TypeNode) -> MatchResult:
    """
    Decide whether ``symbol_type`` is acceptable where ``reference_type`` is expected.

    The decision is delegated to :py:func:`mypy.subtypes.is_subtype`.

    :param symbol_type: symbol type to check
    :param reference_type: reference type to check against
    """
    is_compatible = is_subtype(symbol_type, reference_type)
    return MatchResult.MATCH if is_compatible else MatchResult.MISMATCH


def _check_arguments_compatible(
    callable_type_or_func_def: Union[CallableType, FuncDef],
    reference: Union[CallableType, FuncDef],
) -> bool:
    """
    Compare the argument kinds and names of a callable with a reference.

    Used as a fallback when mypy provides no type information or its subtype
    check is not strict enough. The arguments are compatible when:

    - the ``ARG_POS`` arguments are identical and in the same order,
    - the ``ARG_NAMED`` arguments are the same set (order does not matter),
    - every argument of any other kind (presumably optional, ``*args`` and
      ``**kwargs`` parameters) also appears, with the same kind and name,
      in the reference.

    :param callable_type_or_func_def: CallableType or FuncDef whose
        ``arg_kinds``/``arg_names`` are checked
    :param reference: CallableType or FuncDef to check against
    """
    required_kinds = {ARG_POS, ARG_NAMED}
    own_args = list(
        zip(callable_type_or_func_def.arg_kinds, callable_type_or_func_def.arg_names)
    )
    reference_args = list(zip(reference.arg_kinds, reference.arg_names))

    def of_kind(args, wanted_kind):
        # (kind, name) pairs of one kind, in declaration order.
        return [pair for pair in args if pair[0] == wanted_kind]

    if of_kind(own_args, ARG_POS) != of_kind(reference_args, ARG_POS):
        return False
    if set(of_kind(own_args, ARG_NAMED)) != set(of_kind(reference_args, ARG_NAMED)):
        return False

    own_optional = {pair for pair in own_args if pair[0] not in required_kinds}
    reference_optional = {
        pair for pair in reference_args if pair[0] not in required_kinds
    }
    return own_optional <= reference_optional


def _callable_types_match(
    callable_type: CallableType, reference_callable: CallableType
) -> MatchResult:
    """
    Check a callable against a reference callable.

    The argument lists must be compatible (see ``_check_arguments_compatible``)
    and the callable must be a subtype of the reference.

    :param callable_type: callable to check
    :param reference_callable: callable to check against
    """
    if not _check_arguments_compatible(callable_type, reference_callable):
        return MatchResult.MISMATCH
    return _mypy_types_match(callable_type, reference_callable)


def _overloaded_types_match(
    overloaded: Overloaded, reference_overloaded: Overloaded
) -> MatchResult:
    """
    Check an overloaded type against a reference overloaded type.

    Both need the same number of items, and each item has to match the
    reference item at the same position.

    :param overloaded: overloaded type to check
    :param reference_overloaded: overloaded type to check against
    """
    items = overloaded.items
    reference_items = reference_overloaded.items
    if len(items) != len(reference_items):
        return MatchResult.MISMATCH

    every_item_matches = all(
        _callable_types_match(item, reference_item) == MatchResult.MATCH
        for item, reference_item in zip(items, reference_items)
    )
    return MatchResult.MATCH if every_item_matches else MatchResult.MISMATCH
def compare_mypy_types(
    symbol: RelevantSymbolNode,
    reference: RelevantSymbolNode,
    symbol_type: Optional[TypeNode],
    reference_type: Optional[TypeNode],
) -> ComparisonResult:
    """
    Check if the mypy type of the given symbol node is compatible with the
    reference symbol.

    The outcome is decided as follows:

    - If ``reference_type`` is ``None`` (mypy does not have enough information):

      - a function definition is compared by argument kinds and names only
        (see ``_check_arguments_compatible``); it never matches a reference
        that is not a function definition,
      - every other symbol is accepted as correct.

    - If only ``symbol_type`` is ``None``, the result is a mismatch.
    - Two callables match if their arguments are compatible and the symbol
      type is a subtype of the reference type.
    - Two overloaded types match if they have the same number of items and
      every item matches the reference item at the same position.
    - Any other pair of types matches if the symbol type is a subtype of the
      reference type (see :py:func:`mypy.subtypes.is_subtype`).

    :param symbol: symbol node to validate
    :param reference: symbol node to validate against
    :param symbol_type: type of the symbol to validate
    :param reference_type: type of the symbol to validate against
    :return: comparison result whose ``match_result`` is either
        ``MatchResult.MATCH`` or ``MatchResult.MISMATCH``
    """
    if reference_type is None:
        # The reference type will never be None for overloaded functions.
        if isinstance(symbol, FuncDef):
            if not isinstance(reference, FuncDef):
                # Explicit check instead of ``assert`` (stripped under ``python -O``):
                # the argument comparison below needs a FuncDef on both sides.
                return ComparisonResult.create_mismatch(
                    symbol=symbol, reference=reference
                )
            if _check_arguments_compatible(symbol, reference):
                return ComparisonResult.create_match(
                    symbol=symbol,
                    reference=reference,
                    message="Matched function definition without reference type.",
                )
            return ComparisonResult.create_mismatch(
                symbol=symbol, reference=reference, message="Arguments don't match."
            )
        # MyPy does not have enough type information
        # hence we accept that our stub is correct
        return ComparisonResult.create_match(
            symbol=symbol, reference=reference, message="Generated type is None"
        )

    if symbol_type is None:
        return ComparisonResult.create_mismatch(symbol=symbol, reference=reference)

    if isinstance(symbol_type, CallableType) and isinstance(
        reference_type, CallableType
    ):
        match = _callable_types_match(symbol_type, reference_type)
    elif isinstance(symbol_type, Overloaded) and isinstance(reference_type, Overloaded):
        match = _overloaded_types_match(symbol_type, reference_type)
    else:
        match = _mypy_types_match(symbol_type, reference_type)

    return ComparisonResult.create(
        match_result=match, symbol=symbol, reference=reference
    )
def _type_infos_are_same_class(
    symbol: TypeInfo, reference: TypeInfo
) -> ComparisonResult:
    """
    Check whether two TypeInfo symbols describe the same class.

    Only the full names are compared: what matters here is that both classes
    are defined at the same location. Instance fields and methods are usually
    checked one by one elsewhere.

    :param symbol: type info symbol to validate
    :param reference: type info symbol to validate against
    """
    same_location = symbol.fullname == reference.fullname
    build = (
        ComparisonResult.create_match
        if same_location
        else ComparisonResult.create_mismatch
    )
    return build(symbol=symbol, reference=reference)


def _compare_type_aliases(symbol: TypeAlias, reference: TypeAlias) -> ComparisonResult:
    """
    Check whether a TypeAlias symbol is compatible with the reference alias,
    by comparing the types both aliases point to.

    :param symbol: type alias symbol to validate
    :param reference: type alias symbol to validate against
    """
    symbol_target = symbol.target
    reference_target = reference.target
    return compare_mypy_types(symbol, reference, symbol_target, reference_target)


def _format_type_var(symbol: TypeVarExpr) -> str:
    """
    Render a TypeVarExpr the way its declaration would look in source code,
    e.g. ``T = TypeVar('T', int, str, covariant=True)``.

    :param symbol: TypeVarExpr to format
    """
    variance_suffixes = {
        COVARIANT: ", covariant=True",
        CONTRAVARIANT: ", contravariant=True",
    }
    variance = variance_suffixes.get(symbol.variance, "")
    values = "".join(", " + str(value) for value in symbol.values)
    return f"{symbol.name} = TypeVar('{symbol.name}'{values}{variance})"


def _match_type_var_expr(symbol: TypeVarExpr, reference: TypeVarExpr) -> MatchResult:
    """
    Decide whether a TypeVarExpr matches the reference TypeVarExpr.

    The variances must be equal. If either type variable has value
    restrictions, both need the same number of values and each value must
    match the reference value at the same position; otherwise the upper
    bounds are compared.

    :param symbol: type var symbol to validate
    :param reference: type var symbol to validate against
    """
    if symbol.variance != reference.variance:
        return MatchResult.MISMATCH

    if symbol.values or reference.values:
        if len(symbol.values) != len(reference.values):
            return MatchResult.MISMATCH
        for symbol_value, reference_value in zip(symbol.values, reference.values):
            outcome = compare_mypy_types(
                symbol, reference, symbol_value, reference_value
            ).match_result
            if outcome is not MatchResult.MATCH:
                return outcome
        return MatchResult.MATCH

    bound_comparison = compare_mypy_types(
        symbol, reference, symbol.upper_bound, reference.upper_bound
    )
    return bound_comparison.match_result


def _compare_type_var_expr(
    symbol: TypeVarExpr, reference: TypeVarExpr
) -> ComparisonResult:
    """
    Compare a TypeVarExpr symbol with the reference and wrap the outcome in
    a comparison result.

    :param symbol: type var symbol to validate
    :param reference: type var symbol to validate against
    """
    outcome = _match_type_var_expr(symbol, reference)
    return ComparisonResult.create(
        match_result=outcome, symbol=symbol, reference=reference
    )


def _compare_decorator(symbol: Decorator, reference: Decorator) -> ComparisonResult:
    """
    Check whether a Decorator symbol matches the reference.

    The comparison succeeds when both apply the same decorators in the same
    order and the decorated functions match each other.

    :param symbol: decorator symbol to validate
    :param reference: decorator symbol to validate against
    """
    symbol_decorators = [
        get_expression_fullname(expr) for expr in symbol.original_decorators
    ]
    reference_decorators = [
        get_expression_fullname(expr) for expr in reference.original_decorators
    ]

    if symbol_decorators != reference_decorators:
        return ComparisonResult.create_mismatch(
            symbol=symbol,
            reference=reference,
            data={
                "Symbol decorators": symbol_decorators,
                "Reference decorators": reference_decorators,
            },
            message=(
                f"Function {symbol.func.fullname} stubs have different decorators."
            ),
        )

    # Same decorator chain: the outcome is that of the decorated functions.
    function_comparison = compare_symbols(symbol.func, reference.func)
    return ComparisonResult.create(
        match_result=function_comparison.match_result,
        symbol=symbol,
        reference=reference,
        data=function_comparison.data,
        message=function_comparison.message,
    )
def compare_symbols(
    symbol: RelevantSymbolNode, reference: RelevantSymbolNode
) -> ComparisonResult:
    """
    Check if the given symbol node is compatible with the reference symbol.

    Symbols whose node classes differ never match. Otherwise the comparison
    depends on the node class:

    - classes (``TypeInfo``) match if they have the same full name,
    - type aliases match if their target types are compatible,
    - type variables match if their variance is equal and their value
      restrictions (or, without any, their upper bounds) are compatible,
    - decorated functions match if the same decorators are applied in the
      same order and the decorated functions match,
    - any other symbols are compared by their ``type`` attribute with
      :py:func:`compare_mypy_types`, which is based on
      :py:func:`mypy.subtypes.is_subtype`.

    :param symbol: symbol node to validate
    :param reference: symbol node to validate against
    """
    # TODO: Check if this is always the case, i.e. could there be
    # cases where `symbol` and `reference` don't have the same class but still match?
    if type(symbol) != type(reference):
        return ComparisonResult.create_mismatch(symbol=symbol, reference=reference)

    if isinstance(symbol, TypeInfo) and isinstance(reference, TypeInfo):
        return _type_infos_are_same_class(symbol, reference)

    if isinstance(symbol, TypeAlias) and isinstance(reference, TypeAlias):
        return _compare_type_aliases(symbol, reference)

    if isinstance(symbol, TypeVarExpr) and isinstance(reference, TypeVarExpr):
        return _compare_type_var_expr(symbol, reference)

    if isinstance(symbol, Decorator) and isinstance(reference, Decorator):
        return _compare_decorator(symbol, reference)

    # NOTE(review): ``getattr`` with a constant name presumably keeps the access
    # opaque to the type checker, since not every member of RelevantSymbolNode
    # may declare ``type`` — confirm before switching to plain attribute access.
    return compare_mypy_types(
        symbol, reference, getattr(symbol, "type"), getattr(reference, "type")
    )