# (C) Copyright 2005-2022 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

import copy
import pickle
import unittest

from traits.has_traits import (
    update_traits_class_dict,
    on_trait_change,
    BaseTraits,
    ClassTraits,
    PrefixTraits,
    ListenerTraits,
    InstanceTraits,
    HasTraits,
    observe,
    ObserverTraits,
    SingletonHasTraits,
    SingletonHasStrictTraits,
    SingletonHasPrivateTraits,
    MetaHasTraits,
)
from traits.ctrait import CTrait
from traits.observation.api import (
    compile_str,
    compile_expr,
    NotifierNotFound,
    trait,
)
from traits.observation.exception_handling import (
    pop_exception_handler,
    push_exception_handler,
)
from traits.traits import ForwardProperty, generic_trait
from traits.trait_types import Event, Float, Instance, Int, List, Map, Str
from traits.trait_errors import TraitError


def _dummy_getter(self):
    pass


def _dummy_setter(self, value):
    pass


def _dummy_validator(self, value):
    pass


class TestMetaHasTraits(unittest.TestCase):
    """Tests for the deprecated listener hooks on MetaHasTraits."""

    def test_add_listener_and_remove_listener_deprecated(self):
        def noop_listener(cls):
            pass

        # Each hook must emit a DeprecationWarning whose message names it.
        # The pairs are exercised in order: add first, then remove.
        deprecated_hooks = [
            (MetaHasTraits.add_listener, "add_listener is deprecated"),
            (MetaHasTraits.remove_listener, "remove_listener is deprecated"),
        ]
        for hook, message in deprecated_hooks:
            with self.assertWarnsRegex(DeprecationWarning, message):
                hook(noop_listener)


class TestCreateTraitsMetaDict(unittest.TestCase):
    """Tests for ``update_traits_class_dict``.

    Each test builds a plain class dictionary, passes it through
    ``update_traits_class_dict`` (which mutates it in place) and then checks
    how the entries were distributed among the traits dictionaries stored
    under the ``BaseTraits``, ``ClassTraits``, ``PrefixTraits``,
    ``ListenerTraits``, ``InstanceTraits`` and ``ObserverTraits`` keys.
    """

    def test_class_attributes(self):
        # Given
        class_name = "MyClass"
        bases = (object,)
        class_dict = {"attr": "something"}

        # When
        update_traits_class_dict(class_name, bases, class_dict)

        # Then: check that the original Python-level class attributes are
        # still present in the class dictionary.
        self.assertEqual(class_dict["attr"], "something")

        # Other traits dictionaries should be empty.
        for kind in (BaseTraits, ClassTraits, ListenerTraits, InstanceTraits):
            self.assertEqual(class_dict[kind], {})

    def test_forward_property(self):
        # Given
        class_name = "MyClass"
        bases = (object,)
        class_dict = {
            "attr": "something",
            "my_property": ForwardProperty({}),
            "_get_my_property": _dummy_getter,
            "_set_my_property": _dummy_setter,
            "_validate_my_property": _dummy_validator,
        }

        # When
        update_traits_class_dict(class_name, bases, class_dict)

        # Then
        self.assertEqual(class_dict[ListenerTraits], {})
        self.assertEqual(class_dict[InstanceTraits], {})

        # Both ClassTraits and BaseTraits should contain a single trait (the
        # property we created)
        self.assertEqual(len(class_dict[BaseTraits]), 1)
        self.assertEqual(len(class_dict[ClassTraits]), 1)
        self.assertIs(
            class_dict[BaseTraits]["my_property"],
            class_dict[ClassTraits]["my_property"],
        )

        # The class_dict should still have the entry for `attr`, but not
        # `my_property`.
        self.assertEqual(class_dict["attr"], "something")
        self.assertNotIn("my_property", class_dict)

    def test_standard_trait(self):
        # Given
        class_name = "MyClass"
        bases = (object,)
        class_dict = {"attr": "something", "my_int": Int}

        # When
        update_traits_class_dict(class_name, bases, class_dict)

        # Then
        self.assertEqual(class_dict[ListenerTraits], {})
        self.assertEqual(class_dict[InstanceTraits], {})

        # Both ClassTraits and BaseTraits should contain a single trait (the
        # Int trait)
        self.assertEqual(len(class_dict[BaseTraits]), 1)
        self.assertEqual(len(class_dict[ClassTraits]), 1)
        self.assertIs(
            class_dict[BaseTraits]["my_int"], class_dict[ClassTraits]["my_int"]
        )

        # The class_dict should still have the entry for `attr`, but not
        # `my_int`.
        self.assertEqual(class_dict["attr"], "something")
        self.assertNotIn("my_int", class_dict)

    def test_prefix_trait(self):
        # Given
        class_name = "MyClass"
        bases = (object,)
        class_dict = {"attr": "something", "my_int_": Int}  # prefix trait

        # When
        update_traits_class_dict(class_name, bases, class_dict)

        # Then
        for kind in (BaseTraits, ClassTraits, ListenerTraits, InstanceTraits):
            self.assertEqual(class_dict[kind], {})

        # The prefix trait is registered under its name with the trailing
        # underscore stripped.
        self.assertIn("my_int", class_dict[PrefixTraits])

        # The class_dict should still have the entry for `attr`, but the
        # prefix trait's original key `my_int_` must have been removed.
        # (Checking for `my_int` here would be vacuous: that key was never
        # part of class_dict.)
        self.assertEqual(class_dict["attr"], "something")
        self.assertNotIn("my_int_", class_dict)

    def test_listener_trait(self):
        # Given
        @on_trait_change("something")
        def listener(self):
            pass

        class_name = "MyClass"
        bases = (object,)
        class_dict = {"attr": "something", "my_listener": listener}

        # When
        update_traits_class_dict(class_name, bases, class_dict)

        # Then
        self.assertEqual(class_dict[BaseTraits], {})
        self.assertEqual(class_dict[ClassTraits], {})
        self.assertEqual(class_dict[InstanceTraits], {})
        self.assertEqual(
            class_dict[ListenerTraits],
            {
                "my_listener": (
                    "method",
                    {
                        "pattern": "something",
                        "post_init": False,
                        "dispatch": "same",
                    },
                )
            },
        )

    def test_observe_trait(self):
        # Given
        @observe(trait("value"), post_init=True, dispatch="ui")
        @observe("name")
        def handler(self, event):
            pass

        class_name = "MyClass"
        bases = (object,)
        class_dict = {"attr": "something", "my_listener": handler}

        # When
        update_traits_class_dict(class_name, bases, class_dict)

        # Then: decorators apply bottom-up, so the observer registered by
        # the innermost @observe("name") is listed first.
        self.assertEqual(
            class_dict[ObserverTraits],
            {
                "my_listener": [
                    {
                        "graphs": compile_str("name"),
                        "post_init": False,
                        "dispatch": "same",
                        "handler_getter": getattr,
                    },
                    {
                        "graphs": compile_expr(trait("value")),
                        "post_init": True,
                        "dispatch": "ui",
                        "handler_getter": getattr,
                    },
                ],
            },
        )

    def test_python_property(self):
        # Given
        class_name = "MyClass"
        bases = (object,)
        class_dict = {
            "attr": "something",
            "my_property": property(_dummy_getter),
        }

        # When
        update_traits_class_dict(class_name, bases, class_dict)

        # Then: a plain Python property is recorded only in ClassTraits,
        # as the shared generic trait.
        self.assertEqual(class_dict[BaseTraits], {})
        self.assertEqual(class_dict[InstanceTraits], {})
        self.assertEqual(class_dict[ListenerTraits], {})
        self.assertIs(class_dict[ClassTraits]["my_property"], generic_trait)

    def test_complex_baseclass(self):
        # Given
        class Base(HasTraits):
            x = Int

        class_name = "MyClass"
        bases = (Base,)
        class_dict = {"attr": "something", "my_trait": Float()}

        # When
        update_traits_class_dict(class_name, bases, class_dict)

        # Then: traits inherited from the base class ("x") and traits
        # defined on the new class ("my_trait") both appear, as the same
        # objects, in BaseTraits and ClassTraits.
        self.assertEqual(class_dict[InstanceTraits], {})
        self.assertEqual(class_dict[ListenerTraits], {})
        self.assertIs(
            class_dict[BaseTraits]["x"], class_dict[ClassTraits]["x"]
        )
        self.assertIs(
            class_dict[BaseTraits]["my_trait"],
            class_dict[ClassTraits]["my_trait"],
        )


class TestHasTraits(unittest.TestCase):
    """Tests for HasTraits private introspection hooks, notification
    enable/veto switches, prefix-trait lookup, deepcopy, add_class_trait,
    metadata filtering in traits(), and *_changed handler naming."""

    def test__class_traits(self):
        # Exercise the _class_traits() private introspection method.
        class Base(HasTraits):
            pin = Int

        a = Base()
        a_class_traits = a._class_traits()
        self.assertIsInstance(a_class_traits, dict)
        self.assertIn("pin", a_class_traits)
        self.assertIsInstance(a_class_traits["pin"], CTrait)

        # All instances of the class share the same class-traits dictionary.
        b = Base()
        self.assertIs(b._class_traits(), a_class_traits)

    def test__instance_traits(self):
        # Exercise the _instance_traits() private introspection method.
        class Base(HasTraits):
            pin = Int

        a = Base()
        a_instance_traits = a._instance_traits()
        self.assertIsInstance(a_instance_traits, dict)

        # A second call should return the same dictionary.
        self.assertIs(a._instance_traits(), a_instance_traits)

        # A different instance should have its own instance traits dict.
        b = Base()
        self.assertIsNot(b._instance_traits(), a_instance_traits)

    def test__trait_notifications_enabled(self):
        class Base(HasTraits):
            foo = Int(0)

            foo_notify_count = Int(0)

            def _foo_changed(self):
                self.foo_notify_count += 1

        a = Base()

        # Default state is that notifications are enabled.
        self.assertTrue(a._trait_notifications_enabled())

        # Changing foo increments the count.
        old_count = a.foo_notify_count
        a.foo += 1
        self.assertEqual(a.foo_notify_count, old_count + 1)

        # After disabling notifications, count is not increased.
        a._trait_change_notify(False)
        self.assertFalse(a._trait_notifications_enabled())
        old_count = a.foo_notify_count
        a.foo += 1
        self.assertEqual(a.foo_notify_count, old_count)

        # After re-enabling notifications, count is increased.
        a._trait_change_notify(True)
        self.assertTrue(a._trait_notifications_enabled())
        old_count = a.foo_notify_count
        a.foo += 1
        self.assertEqual(a.foo_notify_count, old_count + 1)

    def test__trait_notifications_vetoed(self):
        # Vetoing is a property of the *value* being assigned (here the
        # event object), not of the object receiving the assignment.
        class SomeEvent(HasTraits):
            event_id = Int()

        class Target(HasTraits):
            event = Event(Instance(SomeEvent))

            event_count = Int(0)

            def _event_fired(self):
                self.event_count += 1

        target = Target()
        event = SomeEvent(event_id=1234)

        # Default state is not vetoed.
        self.assertFalse(event._trait_notifications_vetoed())

        # Firing the event increments the count.
        old_count = target.event_count
        target.event = event
        self.assertEqual(target.event_count, old_count + 1)

        # Now veto the event. Firing the event won't affect the count.
        event._trait_veto_notify(True)
        self.assertTrue(event._trait_notifications_vetoed())
        old_count = target.event_count
        target.event = event
        self.assertEqual(target.event_count, old_count)

        # Unveto the event.
        event._trait_veto_notify(False)
        self.assertFalse(event._trait_notifications_vetoed())
        old_count = target.event_count
        target.event = event
        self.assertEqual(target.event_count, old_count + 1)

    def test__object_notifiers_vetoed(self):
        # Same as test__trait_notifications_vetoed, but the handler is an
        # object-level ("anytrait") notifier rather than a static
        # _event_fired method.

        class SomeEvent(HasTraits):
            event_id = Int()

        class Target(HasTraits):
            event = Event(Instance(SomeEvent))

            event_count = Int(0)

        target = Target()
        event = SomeEvent(event_id=9)

        def object_handler(obj, name, old, new):
            # An "anytrait" handler is also invoked when event_count itself
            # changes, so only count changes of the "event" trait.
            if name == "event":
                obj.event_count += 1

        target.on_trait_change(object_handler, name="anytrait")

        # Default state is not vetoed.
        self.assertFalse(event._trait_notifications_vetoed())

        # Firing the event increments the count.
        old_count = target.event_count
        target.event = event
        self.assertEqual(target.event_count, old_count + 1)

        # Now veto the event. Firing the event won't affect the count.
        event._trait_veto_notify(True)
        self.assertTrue(event._trait_notifications_vetoed())
        old_count = target.event_count
        target.event = event
        self.assertEqual(target.event_count, old_count)

        # Unveto the event.
        event._trait_veto_notify(False)
        self.assertFalse(event._trait_notifications_vetoed())
        old_count = target.event_count
        target.event = event
        self.assertEqual(target.event_count, old_count + 1)

    def test_traits_inited(self):
        foo = HasTraits()

        self.assertTrue(foo.traits_inited())

    def test__trait_set_inited(self):
        # Bypass __init__ so the "inited" flag starts out unset.
        foo = HasTraits.__new__(HasTraits)

        self.assertFalse(foo.traits_inited())

        foo._trait_set_inited()

        self.assertTrue(foo.traits_inited())

    def test_generic_getattr_exception(self):
        # Regression test for enthought/traits#946.

        class PropertyLike:
            """
            Data descriptor giving a property-like object that produces
            successive reciprocals on __get__. This means that it raises
            on first access, but not on subsequent accesses.
            """
            def __init__(self):
                self.n = 0

            # The owner-type parameter is named ``objtype`` rather than
            # ``type`` to avoid shadowing the builtin; the descriptor
            # protocol passes it positionally.
            def __get__(self, obj, objtype=None):
                old_n = self.n
                self.n += 1
                return 1 / old_n

            # Need a __set__ method to make this a data descriptor.
            def __set__(self, obj, value):
                raise AttributeError("Read-only descriptor")

        class A(HasTraits):
            fruit = PropertyLike()

            banana_ = Int(1729)

        a = A()

        # The exception raised on the first attribute access should be
        # propagated.
        with self.assertRaises(ZeroDivisionError):
            a.fruit

        # Exercise the code path where the PyObject_GenericGetAttr call raises
        # AttributeError. In this case, we catch the error but the prefix trait
        # machinery raises a new AttributeError.
        with self.assertRaises(AttributeError):
            a.veg

        # Exercise the case where the prefix traits machinery goes on to
        # produce a valid result.
        self.assertEqual(a.banananana, 1729)

    def test_deepcopy_memoization(self):
        # deepcopy must honour its memo dict: an object referenced twice
        # in the source is copied once and referenced twice in the copy.
        class A(HasTraits):
            x = Int()
            y = Str()

        a = A()
        objs = [a, a]
        objs_copy = copy.deepcopy(objs)
        self.assertIsNot(objs_copy[0], objs[0])
        self.assertIs(objs_copy[0], objs_copy[1])

    def test_add_class_trait(self):
        # Testing basic usage.
        class A(HasTraits):
            pass

        A.add_class_trait("y", Str())

        a = A()

        self.assertEqual(a.y, "")

    def test_add_class_trait_affects_existing_instances(self):
        class A(HasTraits):
            pass

        a = A()

        A.add_class_trait("y", Str())

        self.assertEqual(a.y, "")

    def test_add_class_trait_affects_subclasses(self):
        class A(HasTraits):
            pass

        class B(A):
            pass

        class C(B):
            pass

        class D(B):
            pass

        A.add_class_trait("y", Str())
        self.assertEqual(A().y, "")
        self.assertEqual(B().y, "")
        self.assertEqual(C().y, "")
        self.assertEqual(D().y, "")

    def test_add_class_trait_has_items_and_subclasses(self):
        # Regression test for enthought/traits#1460
        class A(HasTraits):
            pass

        class B(A):
            pass

        class C(B):
            pass

        # Code branch for traits with items.
        A.add_class_trait("x", List(Int))
        self.assertEqual(A().x, [])
        self.assertEqual(B().x, [])
        self.assertEqual(C().x, [])

        # Exercise the code branch for mapped traits.
        A.add_class_trait("y", Map({"yes": 1, "no": 0}, default_value="no"))
        self.assertEqual(A().y, "no")
        self.assertEqual(B().y, "no")
        self.assertEqual(C().y, "no")

    def test_add_class_trait_add_prefix_traits(self):
        # Names ending in "_" are added as prefix traits. Lookups pick the
        # trait registered for the longest matching prefix.

        class A(HasTraits):
            pass

        A.add_class_trait("abc_", Str())
        A.add_class_trait("abc_def_", Int())

        a = A()
        self.assertEqual(a.abc_def_g, 0)
        self.assertEqual(a.abc_z, "")

    def test_add_class_trait_when_trait_already_exists(self):
        # Redefining an existing trait is rejected and leaves the original
        # trait (and no "_items" companion) in place.

        class A(HasTraits):
            foo = Int()

        with self.assertRaises(TraitError):
            A.add_class_trait("foo", List())

        self.assertEqual(A().foo, 0)
        with self.assertRaises(AttributeError):
            A().foo_items

    def test_add_class_trait_when_trait_already_exists_in_subclass(self):
        # A subclass's own definition wins over a trait added later to
        # its base class.
        class A(HasTraits):
            pass

        class B(A):
            foo = Int()

        A.add_class_trait("foo", Str())

        self.assertEqual(A().foo, "")
        self.assertEqual(B().foo, 0)

    def test_traits_method_with_dunder_metadata(self):
        # Regression test for enthought/envisage#430
        class A(HasTraits):
            foo = Int(__extension_point__=True)
            bar = Int(__extension_point__=False)
            baz = Int()

        a = A(foo=3, bar=4, baz=5)

        self.assertEqual(
            a.traits(__extension_point__=True),
            {"foo": a.trait("foo")},
        )

        self.assertEqual(
            A.class_traits(__extension_point__=True),
            {"foo": A.class_traits()["foo"]},
        )

    def test_decorated_changed_method(self):
        # xref: enthought/traits#527
        # Traits should ignore the _changed magic naming.

        events = []

        class A(HasTraits):
            foo = Int()

            @on_trait_change("foo")
            def _foo_changed(self, obj, name, old, new):
                events.append((obj, name, old, new))

        a = A()
        a.foo = 23
        # Exactly one call: the decorated handler is not also invoked a
        # second time as a magic _foo_changed handler.
        self.assertEqual(
            events,
            [(a, "foo", 0, 23)],
        )

    def test_observed_changed_method(self):
        # An @observe-decorated method named like a magic _changed handler
        # must be called exactly once per change.
        events = []

        class A(HasTraits):
            foo = Int()

            @observe("foo")
            def _foo_changed(self, event):
                events.append(event)

        a = A()
        a.foo = 23
        self.assertEqual(len(events), 1)
        event = events[0]
        self.assertEqual(event.object, a)
        self.assertEqual(event.name, "foo")
        self.assertEqual(event.old, 0)
        self.assertEqual(event.new, 23)

    def test_decorated_changed_method_subclass(self):
        # xref: enthought/traits#527
        # Traits should ignore the _changed magic naming.

        events = []

        class A(HasTraits):
            foo = Int()

            @on_trait_change("foo")
            def _foo_changed(self, obj, name, old, new):
                events.append((obj, name, old, new))

        class B(A):
            pass

        a = B()
        a.foo = 23
        self.assertEqual(
            events,
            [(a, "foo", 0, 23)],
        )


class TestObjectNotifiers(unittest.TestCase):
    """Tests for the object-level notifier list exposed by _notifiers()."""

    def test_notifiers_empty(self):
        class Foo(HasTraits):
            x = Int()

        instance = Foo(x=1)
        # No "anytrait" handler has been registered, so the list is empty.
        self.assertEqual(instance._notifiers(True), [])

    def test_notifiers_on_object(self):
        class Foo(HasTraits):
            x = Int()

        instance = Foo(x=1)
        self.assertEqual(instance._notifiers(True), [])

        def handler():
            pass

        # Registering an "anytrait" handler adds exactly one object-level
        # notifier, which exposes the registered callable as `.handler`.
        instance.on_trait_change(handler, name="anytrait")

        registered = instance._notifiers(True)
        self.assertEqual(len(registered), 1)
        self.assertEqual(registered[0].handler, handler)


class TestCallNotifiers(unittest.TestCase):
    """Check the order in which trait-level and object-level notifiers run."""

    def test_trait_and_object_notifiers_called(self):
        calls = []

        class Foo(HasTraits):
            x = Int()
            y = Int()

            def _x_changed(self):
                calls.append("x")

        def on_any_trait():
            calls.append("object")

        foo = Foo()
        foo.on_trait_change(on_any_trait, name="anytrait")

        # x has a static handler: it runs before the object-level handler.
        del calls[:]
        foo.x = 3
        self.assertEqual(calls, ["x", "object"])

        # y has no static handler: only the object-level handler runs.
        del calls[:]
        foo.y = 4
        self.assertEqual(calls, ["object"])

    def test_trait_notifier_modify_object_notifier(self):
        # A trait notifier registers a second object notifier as a side
        # effect. The newly added notifier must not run for the change that
        # is currently being dispatched, only for later ones.

        calls = []

        def first_object_handler():
            calls.append("object1")

        def second_object_handler():
            calls.append("object2")

        class Foo(HasTraits):
            x = Int()
            y = Int()

            def _x_changed(self):
                calls.append("x")

                # Register the second object-level handler mid-dispatch.
                self.on_trait_change(second_object_handler, name="anytrait")

        # Register one object handler up front so the object notifier list
        # already exists when _x_changed mutates it.
        foo = Foo()
        foo.on_trait_change(first_object_handler, name="anytrait")

        del calls[:]
        foo.x = 1
        # The handler added during this dispatch is not called yet.
        self.assertEqual(calls, ["x", "object1"])

        del calls[:]
        foo.y = 2
        # On the next change both object-level handlers run, in order.
        self.assertEqual(calls, ["object1", "object2"])


class TestDeprecatedHasTraits(unittest.TestCase):
    """Instantiating subclasses of the Singleton*Traits bases must warn."""

    def test_deprecated(self):
        class TestSingletonHasTraits(SingletonHasTraits):
            pass

        class TestSingletonHasStrictTraits(SingletonHasStrictTraits):
            pass

        class TestSingletonHasPrivateTraits(SingletonHasPrivateTraits):
            pass

        deprecated_subclasses = (
            TestSingletonHasTraits,
            TestSingletonHasStrictTraits,
            TestSingletonHasPrivateTraits,
        )
        for subclass in deprecated_subclasses:
            with self.assertWarns(DeprecationWarning):
                subclass()


class MappedWithDefault(HasTraits):
    """Fixture with a Map trait whose default comes from a _default method.

    Kept at module level so that instances can be pickled.
    """

    married = Map({"yes": 1, "yeah": 1, "no": 0, "nah": 0})

    # Number of times _married_default has been invoked on this instance.
    default_calls = Int(0)

    def _married_default(self):
        self.default_calls = self.default_calls + 1
        return "yes"


class TestHasTraitsPickling(unittest.TestCase):
    """Pickle round-trip tests for HasTraits instances."""

    def test_pickle_mapped_default_method(self):
        original = MappedWithDefault()

        # Sanity check: the dynamic default has not been requested yet.
        self.assertEqual(original.default_calls, 0)

        payload = pickle.dumps(original)
        restored = pickle.loads(payload)

        # After the round trip the mapped value and the user-facing value
        # agree, and the default method has run exactly once. The order of
        # these accesses is deliberate: married_ is read before married.
        self.assertEqual(restored.married_, 1)
        self.assertEqual(restored.married, "yes")
        self.assertEqual(restored.default_calls, 1)


class Person(HasTraits):
    """Minimal HasTraits fixture holding a single integer ``age`` trait."""

    age = Int()


class PersonWithObserve(Person):
    """Person that records every ``age`` change event in ``events``.

    Defined at module level so the pickling tests can serialize it.
    """

    events = List()

    @observe(trait("age"))
    def handler(self, change):
        # Keep the observer event so tests can count notifications.
        self.events.append(change)


class TestHasTraitsObserveHook(unittest.TestCase):
    """ Test observe decorator and the observe method.
    """

    def setUp(self):
        # Make exceptions raised inside observer handlers propagate, so a
        # failing handler surfaces as a test failure. Restored on cleanup.
        push_exception_handler(reraise_exceptions=True)
        self.addCleanup(pop_exception_handler)

    def test_overloaded_signature_expression(self):
        # Test the overloaded signature for expression: an ObserverExpression,
        # a string, or a list of either must all be accepted.
        expressions = [
            trait("age"),
            "age",
            [trait("age")],
            ["age"],
        ]
        for expression in expressions:
            # subTest reports which expression form failed.
            with self.subTest(expression=expression):

                class NewPerson(Person):
                    events = List()

                    @observe(expression)
                    def handler(self, event):
                        self.events.append(event)

                person = NewPerson()
                person.age += 1
                self.assertEqual(len(person.events), 1)

    def test_observe_method_remove(self):
        events = []
        person = Person()
        person.observe(events.append, "age")

        # sanity check
        person.age += 1
        self.assertEqual(len(events), 1)

        # when
        person.observe(events.append, "age", remove=True)

        # then
        person.age += 1
        self.assertEqual(len(events), 1)  # unchanged

    def test_observe_method_remove_nonexistent_handler(self):
        # Removing a handler that was never added must raise.
        events = []
        person = Person()

        with self.assertRaises(NotifierNotFound):
            person.observe(events.append, "age", remove=True)

    def test_observe_dispatch_ui(self):
        # Test to ensure "ui" is one of the allowed values.
        # Not testing the actual effect as it requires GUI event loop
        # as well as assumption on the local thread identity while running
        # the test.
        person = Person()

        person.observe(repr, trait("age"), dispatch="ui")

    def test_inherit_observer_from_superclass(self):
        # Test observers can be inherited, even when the observed trait is
        # only defined on the subclass.
        class BaseClass(HasTraits):
            events = List()

            @observe("value")
            def handler(self, event):
                self.events.append(event)

        class SubClass(BaseClass):
            value = Int()

        instance = SubClass()
        instance.value += 1
        self.assertEqual(len(instance.events), 1)

    def test_observer_overridden(self):
        # The handler is overridden, no change event should be registered.
        class BaseClass(HasTraits):
            events = List()

            @observe("value")
            def handler(self, event):
                self.events.append(event)

        class SubclassOverridden(BaseClass):
            value = Int()
            handler = None

        instance = SubclassOverridden()
        instance.value += 1
        self.assertEqual(len(instance.events), 0)

    def test_observe_post_init(self):
        # With post_init=True, values passed to the constructor do not
        # trigger the observer; later changes do.

        class PersonWithPostInit(Person):
            events = List()

            @observe("age", post_init=True)
            def handler(self, event):
                self.events.append(event)

        person = PersonWithPostInit(age=10)
        self.assertEqual(len(person.events), 0)

        person.age += 1
        self.assertEqual(len(person.events), 1)

    def test_observe_pickability(self):
        # Test an HasTraits with observe can be pickled, with every pickle
        # protocol, and that observers are re-attached after unpickling.
        person = PersonWithObserve()
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            # subTest reports which protocol failed.
            with self.subTest(protocol=protocol):
                serialized = pickle.dumps(person, protocol=protocol)
                deserialized = pickle.loads(serialized)

                deserialized.age += 1
                self.assertEqual(len(deserialized.events), 1)

    def test_observe_deepcopy(self):
        # Test an HasTraits with observe can be deepcopied: the copy has
        # its own working observer and the original is not notified.
        person = PersonWithObserve()
        copied = copy.deepcopy(person)
        copied.age += 1
        self.assertEqual(len(copied.events), 1)
        self.assertEqual(len(person.events), 0)
