# (C) Copyright 2005-2022 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

""" An attempt at type-safe casting.
"""

from inspect import getfullargspec, getmro
import logging
from types import FunctionType

from .has_traits import HasTraits

logger = logging.getLogger(__name__)


# Constants:

# Message templates for interface errors.
BAD_SIGNATURE = (
    "The '%s' class signature for the '%s' method is different "
    "from that of the '%s' interface."
)
MISSING_METHOD = (
    "The '%s' class does not implement the '%s' method of the "
    "'%s' interface."
)
MISSING_TRAIT = (
    "The '%s' class does not implement the %s trait(s) of the "
    "'%s' interface."
)


class InterfaceError(Exception):
    """ The exception raised if a class does not really implement an interface.
    """

    pass


class InterfaceChecker(HasTraits):
    """ Checks that interfaces are actually implemented.
    """

    def check_implements(self, cls, interfaces, error_mode):
        """ Checks that the class implements the specified interfaces.

            'interfaces' can be a single interface or a list of interfaces.
        """
        # If a single interface was specified then turn it into a list:
        try:
            iter(interfaces)
        except TypeError:
            interfaces = [interfaces]

        # If the class has traits then check that it implements all traits and
        # methods on the specified interfaces:
        if issubclass(cls, HasTraits):
            for interface in interfaces:
                if not self._check_has_traits_class(
                    cls, interface, error_mode
                ):
                    return False

        # Otherwise, just check that the class implements all methods on the
        # specified interface:
        else:
            for interface in interfaces:
                if not self._check_non_has_traits_class(
                    cls, interface, error_mode
                ):
                    return False

        return True

    def _check_has_traits_class(self, cls, interface, error_mode):
        """ Checks that a 'HasTraits' class implements an interface.
        """
        return self._check_traits(
            cls, interface, error_mode
        ) and self._check_methods(cls, interface, error_mode)

    def _check_non_has_traits_class(self, cls, interface, error_mode):
        """ Checks that a non-'HasTraits' class implements an interface.
        """
        return self._check_methods(cls, interface, error_mode)

    def _check_methods(self, cls, interface, error_mode):
        """ Checks that a class implements the methods on an interface.
        """
        cls_methods = self._get_public_methods(cls)
        interface_methods = self._get_public_methods(interface)

        for name in interface_methods:
            if name not in cls_methods:
                return self._handle_error(
                    MISSING_METHOD
                    % (
                        self._class_name(cls),
                        name,
                        self._class_name(interface),
                    ),
                    error_mode,
                )

            # Check that the method signatures are the same:
            cls_argspec = getfullargspec(cls_methods[name])
            interface_argspec = getfullargspec(interface_methods[name])

            if cls_argspec != interface_argspec:
                return self._handle_error(
                    BAD_SIGNATURE
                    % (
                        self._class_name(cls),
                        name,
                        self._class_name(interface),
                    ),
                    error_mode,
                )

        return True

    def _check_traits(self, cls, interface, error_mode):
        """ Checks that a class implements the traits on an interface.
        """
        missing = set(interface.class_traits()).difference(
            set(cls.class_traits())
        )

        if len(missing) > 0:
            return self._handle_error(
                MISSING_TRAIT
                % (
                    self._class_name(cls),
                    repr(list(missing))[1:-1],
                    self._class_name(interface),
                ),
                error_mode,
            )

        return True

    def _get_public_methods(self, cls):
        """ Returns all public methods on a class.

            Returns a dictionary containing all public methods keyed by name.
        """
        public_methods = {}
        for c in getmro(cls):
            # Stop when we get to 'HasTraits'!:
            if c is HasTraits:
                break

            for name, value in c.__dict__.items():
                if (not name.startswith("_")) and (
                    type(value) is FunctionType
                ):
                    if name not in public_methods:
                        public_methods[name] = value

        return public_methods

    def _class_name(self, cls):
        return cls.__name__

    def _handle_error(self, msg, error_mode):
        if error_mode > 1:
            raise InterfaceError(msg)

        if error_mode == 1:
            logger.warning(msg)

        return False


# A default interface checker:
checker = InterfaceChecker()


def check_implements(cls, interfaces, error_mode=0):
    """ Checks that the class implements the specified interfaces.

        'interfaces' can be a single interface or a list of interfaces.
    """
    return checker.check_implements(cls, interfaces, error_mode)
