# (C) Copyright 2005-2022 Enthought, Inc., Austin, TX
# All rights reserved.
#
# This software is provided without warranty under the terms of the BSD
# license included in LICENSE.txt and may be redistributed only under
# the conditions described in the aforementioned license. The license
# is also available online at http://www.enthought.com/licenses/BSD.txt
#
# Thanks for using Enthought open source!

#  Imports

import unittest
import warnings

from traits.api import (
    Any,
    Bytes,
    CBytes,
    CFloat,
    CInt,
    ComparisonMode,
    Color,
    Delegate,
    Float,
    Font,
    HasTraits,
    Instance,
    Int,
    List,
    Range,
    RGBColor,
    Str,
    This,
    Trait,
    TraitError,
    TraitList,
    TraitPrefixList,
    TraitPrefixMap,
    Tuple,
    pop_exception_handler,
    push_exception_handler,
)
from traits.testing.optional_dependencies import requires_traitsui

#  Base unit test classes:


class BaseTest(object):
    """ Mixin providing a generic assignment test for one ``value`` trait.

    Concrete test cases supply ``self.obj`` plus the ``_default_value``,
    ``_good_values``, ``_mapped_values`` and ``_bad_values`` attributes that
    ``test_assignment`` reads.
    """

    def assign(self, value):
        """ Store ``value`` on the ``value`` attribute of ``self.obj``.
        """
        setattr(self.obj, "value", value)

    def coerce(self, value):
        """ Return what is expected to be read back after assigning ``value``.

        The base implementation is the identity.
        """
        return value

    def test_assignment(self):
        target = self.obj

        # The untouched object must report the declared default.
        self.assertEqual(target.value, self._default_value)

        # Every legal value must be accepted and read back in coerced form.
        for position, candidate in enumerate(self._good_values):
            target.value = candidate
            self.assertEqual(target.value, self.coerce(candidate))

            # If there's a mapped value defined for this position, the
            # shadow ``value_`` attribute must hold it.
            if position < len(self._mapped_values):
                self.assertEqual(
                    target.value_, self._mapped_values[position]
                )

        # Every illegal value must be rejected with a TraitError.
        for candidate in self._bad_values:
            self.assertRaises(TraitError, self.assign, candidate)


class test_base2(unittest.TestCase):
    """ Shared helpers for test cases that exercise several named traits.
    """

    def indexed_assign(self, list, index, value):
        """ Perform ``list[index] = value`` in a form usable by assertRaises.
        """
        list[index] = value

    def indexed_range_assign(self, list, index1, index2, value):
        """ Assign ``value`` to the slice ``index1:index2`` of ``list``.
        """
        list[slice(index1, index2)] = value

    def extended_slice_assign(self, list, index1, index2, step, value):
        """ Assign ``value`` to the slice ``index1:index2:step`` of ``list``.
        """
        list[slice(index1, index2, step)] = value

    # The name deliberately avoids the 'test' prefix so that the test runner
    # does not pick this helper up as a test of its own.
    def check_values(
        self,
        name,
        default_value,
        good_values,
        bad_values,
        actual_values=None,
        mapped_values=None,
    ):
        """ Check default, legal and illegal values of trait ``name``.

        ``actual_values`` gives the values expected to be read back for each
        entry of ``good_values`` (defaulting to ``good_values`` itself), and
        ``mapped_values``, when given, the expected ``<name>_`` values.
        """
        target = self.obj

        # The trait must start out at its declared default.
        self.assertEqual(getattr(target, name), default_value)

        # Unless told otherwise, legal values are expected back unchanged.
        expected = actual_values
        if expected is None:
            expected = good_values

        for position, candidate in enumerate(good_values):
            setattr(target, name, candidate)
            self.assertEqual(getattr(target, name), expected[position])
            if mapped_values is not None:
                self.assertEqual(
                    getattr(target, name + "_"), mapped_values[position]
                )

        # Illegal values must all be rejected with a TraitError.
        for candidate in bad_values:
            self.assertRaises(TraitError, setattr, target, name, candidate)


# Fixture: an object whose single ``value`` trait is declared with ``Any``.
class AnyTrait(HasTraits):
    value = Any


class AnyTraitTest(BaseTest, unittest.TestCase):
    """ Run the shared assignment test against ``AnyTrait``.

    The class attributes below also act as defaults for the subclasses that
    only override some of them.
    """

    _default_value = None
    _good_values = [
        10.0,
        b"ten",
        "ten",
        [10],
        {"ten": 10},
        (10,),
        None,
        1j,
    ]
    _bad_values = []
    _mapped_values = []

    def setUp(self):
        self.obj = AnyTrait()


# Fixture: ``value`` is a ``CInt`` trait with default 99.
class CoercibleIntTrait(HasTraits):
    value = CInt(99)


# Fixture: ``value`` is an ``Int`` trait with default 99.
class IntTrait(HasTraits):
    value = Int(99)


class CoercibleIntTest(AnyTraitTest):
    """ Assignment tests for the ``CInt`` trait of ``CoercibleIntTrait``.
    """

    def setUp(self):
        self.obj = CoercibleIntTrait()

    _default_value = 99
    _good_values = [
        10,
        -10,
        10.1,
        -10.1,
        "10",
        "-10",
        b"10",
        b"-10",
    ]
    _bad_values = [
        "10L",
        "-10L",
        "10.1",
        "-10.1",
        b"10L",
        b"-10L",
        b"10.1",
        b"-10.1",
        "ten",
        b"ten",
        [10],
        {"ten": 10},
        (10,),
        None,
        1j,
    ]

    def coerce(self, value):
        """ Return the integer expected to be read back for ``value``.

        ``int`` is tried first; if it rejects the value the conversion is
        retried through ``float``.
        """
        try:
            return int(value)
        except (TypeError, ValueError):
            # Only genuine conversion failures trigger the fallback; a bare
            # ``except`` would also swallow KeyboardInterrupt and SystemExit.
            return int(float(value))


class IntTest(AnyTraitTest):
    """ Assignment tests for the ``Int`` trait of ``IntTrait``.
    """

    def setUp(self):
        self.obj = IntTrait()

    _default_value = 99
    _good_values = [10, -10]
    _bad_values = [
        "ten",
        b"ten",
        [10],
        {"ten": 10},
        (10,),
        None,
        1j,
        10.1,
        -10.1,
        "10L",
        "-10L",
        "10.1",
        "-10.1",
        b"10L",
        b"-10L",
        b"10.1",
        b"-10.1",
        "10",
        "-10",
        b"10",
        b"-10",
    ]

    # When numpy is importable, its integer scalar types are added to the
    # values that must be accepted.
    try:
        import numpy as np
    except ImportError:
        pass
    else:
        _good_values.extend(
            [
                np.int64(10),
                np.int64(-10),
                np.int32(10),
                np.int32(-10),
                np.int_(10),
                np.int_(-10),
            ]
        )

    def coerce(self, value):
        """ Return the integer expected to be read back for ``value``.

        ``int`` is tried first; if it rejects the value the conversion is
        retried through ``float``.
        """
        try:
            return int(value)
        except (TypeError, ValueError):
            # Only genuine conversion failures trigger the fallback; a bare
            # ``except`` would also swallow KeyboardInterrupt and SystemExit.
            return int(float(value))


# Fixture: ``value`` is a ``CFloat`` trait with default 99.0.
class CoercibleFloatTrait(HasTraits):
    value = CFloat(99.0)


# Fixture: ``value`` is a ``Float`` trait with default 99.0.
class FloatTrait(HasTraits):
    value = Float(99.0)


class CoercibleFloatTest(AnyTraitTest):
    """ Assignment tests for the ``CFloat`` trait of ``CoercibleFloatTrait``.
    """

    _default_value = 99.0
    _good_values = [
        10, -10, 10.1, -10.1,
        "10", "-10", "10.1", "-10.1",
        b"10", b"-10", b"10.1", b"-10.1",
    ]
    _bad_values = [
        "10L", "-10L", b"10L", b"-10L",
        "ten", b"ten",
        [10], {"ten": 10}, (10,), None, 1j,
    ]

    def setUp(self):
        self.obj = CoercibleFloatTrait()

    def coerce(self, value):
        """ Return ``value`` converted with ``float``. """
        converted = float(value)
        return converted


class FloatTest(AnyTraitTest):
    """ Assignment tests for the ``Float`` trait of ``FloatTrait``.
    """

    _default_value = 99.0
    _good_values = [10, -10, 10.1, -10.1]
    _bad_values = [
        "ten", b"ten",
        [10], {"ten": 10}, (10,), None, 1j,
        "10", "-10", "10L", "-10L", "10.1", "-10.1",
        b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1",
    ]

    def setUp(self):
        self.obj = FloatTrait()

    def coerce(self, value):
        """ Return ``value`` converted with ``float``. """
        converted = float(value)
        return converted


#  Trait that can only have 'complex'(i.e. imaginary) values:


# Fixture: ``value`` is a ``Trait`` whose default is the complex number
# ``99.0 - 99.0j``.
class ImaginaryValueTrait(HasTraits):
    value = Trait(99.0 - 99.0j)


class ImaginaryValueTest(AnyTraitTest):
    """ Assignment tests for the complex-valued ``ImaginaryValueTrait``.
    """

    _default_value = 99.0 - 99.0j
    _good_values = [
        10, -10, 10.1, -10.1,
        "10", "-10", "10.1", "-10.1",
        10j, 10 + 10j, 10 - 10j,
        10.1j, 10.1 + 10.1j, 10.1 - 10.1j,
        "10j", "10+10j", "10-10j",
    ]
    _bad_values = [
        b"10L",
        "-10L",
        "ten",
        [10],
        {"ten": 10},
        (10,),
        None,
    ]

    def setUp(self):
        self.obj = ImaginaryValueTrait()

    def coerce(self, value):
        """ Return ``value`` converted with ``complex``. """
        converted = complex(value)
        return converted


# Fixture: ``value`` is a ``Trait`` whose default is the str "string".
class StringTrait(HasTraits):
    value = Trait("string")


class StringTest(AnyTraitTest):
    """ Assignment tests for the str-defaulted ``StringTrait``.
    """

    _default_value = "string"
    _good_values = [
        10, -10, 10.1, -10.1,
        "10", "-10", "10L", "-10L", "10.1", "-10.1",
        "string",
        1j,
        [10], ["ten"], {"ten": 10}, (10,),
        None,
    ]
    _bad_values = []

    def setUp(self):
        self.obj = StringTrait()

    def coerce(self, value):
        """ Return ``value`` converted with ``str``. """
        converted = str(value)
        return converted


# Fixture: ``value`` is a ``Bytes`` trait with default b"bytes".
class BytesTrait(HasTraits):
    value = Bytes(b"bytes")


class BytesTest(StringTest):
    """ Assignment tests for the ``Bytes`` trait of ``BytesTrait``.
    """

    _default_value = b"bytes"
    _good_values = [b"", b"10", b"-10"]
    _bad_values = [
        10, -10, 10.1,
        [b""], [b"bytes"], [0],
        {b"ten": b"10"},
        (b"",),
        None,
        True,
        "", "string",
    ]

    def setUp(self):
        self.obj = BytesTrait()

    def coerce(self, value):
        """ Return ``value`` converted with ``bytes``. """
        converted = bytes(value)
        return converted


# Fixture: ``value`` is a ``CBytes`` trait with default b"bytes".
class CoercibleBytesTrait(HasTraits):
    value = CBytes(b"bytes")


class CoercibleBytesTest(StringTest):
    """ Assignment tests for the ``CBytes`` trait of ``CoercibleBytesTrait``.
    """

    _default_value = b"bytes"
    _good_values = [
        b"", b"10", b"-10",
        10,
        [10], (10,), {10}, {10: "foo"},
        True,
    ]
    _bad_values = [
        "", "string",
        -10, 10.1,
        [b""], [b"bytes"],
        [-10], (-10,), {-10: "foo"}, {-10},
        [256], (256,), {256: "foo"}, {256},
        {b"ten": b"10"},
        (b"",),
        None,
    ]

    def setUp(self):
        self.obj = CoercibleBytesTrait()

    def coerce(self, value):
        """ Return ``value`` converted with ``bytes``. """
        converted = bytes(value)
        return converted


# Fixture: ``value`` is a ``Trait`` built from a list of allowed values.
class EnumTrait(HasTraits):
    value = Trait([1, "one", 2, "two", 3, "three", 4.4, "four.four"])


class EnumTest(AnyTraitTest):
    """ Assignment tests for the enumerated ``EnumTrait``.
    """

    _default_value = 1
    _good_values = [
        1, "one",
        2, "two",
        3, "three",
        4.4, "four.four",
    ]
    _bad_values = [0, "zero", 4, None]

    def setUp(self):
        self.obj = EnumTrait()


# Fixture: ``value`` is a ``Trait`` with default "one" and a dict mapping
# the allowed names to integers.
class MappedTrait(HasTraits):
    value = Trait("one", {"one": 1, "two": 2, "three": 3})


class MappedTest(AnyTraitTest):
    """ Assignment tests for the dict-mapped ``MappedTrait``.
    """

    _default_value = "one"
    _good_values = ["one", "two", "three"]
    # Expected ``value_`` for each entry of ``_good_values``, in order.
    _mapped_values = [1, 2, 3]
    _bad_values = [
        "four",
        1, 2, 3,
        [1], (1,), {1: 1},
        None,
    ]

    def setUp(self):
        self.obj = MappedTrait()


# Suppress DeprecationWarning from TraitPrefixList instantiation. The filter
# only applies inside the ``with`` block, so the class is defined within it.
with warnings.catch_warnings():
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)

    # Fixture: ``value`` defaults to "one" and is validated by a
    # ``TraitPrefixList`` over "one", "two" and "three".
    class PrefixListTrait(HasTraits):
        value = Trait("one", TraitPrefixList("one", "two", "three"))


class PrefixListTest(AnyTraitTest):
    """ Assignment tests for the prefix-matched ``PrefixListTrait``.
    """

    _default_value = "one"
    _good_values = [
        "o", "on", "one",
        "tw", "two",
        "th", "thr", "thre", "three",
    ]
    _bad_values = ["t", "one ", " two", 1, None]

    def setUp(self):
        self.obj = PrefixListTrait()

    def coerce(self, value):
        """ Return the full name expected for the prefix ``value``.

        Only the first two characters of ``value`` are used for the lookup.
        """
        expansions = {"o": "one", "on": "one", "tw": "two", "th": "three"}
        return expansions[value[:2]]


# Suppress DeprecationWarning from TraitPrefixMap instantiation. The filter
# only applies inside the ``with`` block, so the class is defined within it.
with warnings.catch_warnings():
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)

    # Fixture: ``value`` defaults to "one" and is validated by a
    # ``TraitPrefixMap`` from "one", "two" and "three" to integers.
    class PrefixMapTrait(HasTraits):
        value = Trait("one", TraitPrefixMap({"one": 1, "two": 2, "three": 3}))


class PrefixMapTest(AnyTraitTest):
    """ Assignment tests for the prefix-mapped ``PrefixMapTrait``.
    """

    def setUp(self):
        self.obj = PrefixMapTrait()

    _default_value = "one"
    _good_values = [
        "o",
        "on",
        "one",
        "tw",
        "two",
        "th",
        "thr",
        "thre",
        "three",
    ]
    # One expected ``value_`` per entry of ``_good_values``, in order. The
    # shared assignment test silently skips positions that have no mapped
    # value, so the two lists must have the same length for every good value
    # (including the final "three") to have its mapping checked.
    _mapped_values = [1, 1, 1, 2, 2, 3, 3, 3, 3]
    _bad_values = ["t", "one ", " two", 1, None]

    def coerce(self, value):
        """ Return the full key expected for the prefix ``value``.

        Only the first two characters of ``value`` are used for the lookup.
        """
        return {"o": "one", "on": "one", "tw": "two", "th": "three"}[value[:2]]


# This tests a combination of Trait, a default, a mapping and a function.

def str_cast_to_int(object, name, value):
    """ Validate that ``value`` is a str and return its length.

    Raises TraitError for any value that is not a str. The ``object`` and
    ``name`` arguments are not used.
    """
    if isinstance(value, str):
        return len(value)
    raise TraitError("Not a string!")


# Fixture: ``value`` combines a default ("white"), a mapping and the
# ``str_cast_to_int`` validator function in a single ``Trait``.
class TraitWithMappingAndCallable(HasTraits):

    value = Trait(
        "white",
        {"white": 0, "red": 1, (0, 0, 0): 999},
        str_cast_to_int,
    )


class TestTraitWithMappingAndCallable(unittest.TestCase):
    """ Test that demonstrates a usage of Trait where TraitMap is used but
    cannot be replaced with Map. The callable causes the key value to be
    changed to match the mapped value.

    e.g. this would not work:

        value = Union(
            Map({"white": 0, "red": 1, (0,0,0): 999}),
            NewTraitType(),
            default_value="white",
        )

        where NewTraitType is a subclass of TraitType whose ``validate``
        simply calls str_cast_to_int.
    """

    def test_trait_default(self):
        instance = TraitWithMappingAndCallable()

        # The value is no longer 'white': both the value and its shadow
        # hold the length of the default string.
        expected = 5
        self.assertEqual(instance.value, expected)
        self.assertEqual(instance.value_, expected)

    def test_trait_set_value_use_callable(self):
        instance = TraitWithMappingAndCallable(value="red")

        # The value is no longer 'red': the callable is used rather than
        # the mapping.
        expected = 3
        self.assertEqual(instance.value, expected)
        self.assertEqual(instance.value_, expected)

    def test_trait_set_value_use_mapping(self):
        key = (0, 0, 0)
        instance = TraitWithMappingAndCallable(value=key)

        # Here the mapping is used: the value keeps the original key and
        # the shadow holds the mapped value.
        self.assertEqual(instance.value, key)
        self.assertEqual(instance.value_, 999)


# Old style class version:


# Fixture hierarchy declared without an explicit base class.
# NOTE(review): under Python 3 these are ordinary classes, so this family
# presumably differs from the ``NTraitTest*`` one only in spelling - confirm.
class OTraitTest1:
    pass


class OTraitTest2(OTraitTest1):
    pass


class OTraitTest3(OTraitTest2):
    pass


# Unrelated class, used to build an instance that must be rejected.
class OBadTraitTest:
    pass


# Shared instance used as the default value of the trait below.
otrait_test1 = OTraitTest1()


# Fixture: ``value`` is a ``Trait`` whose default is an ``OTraitTest1``
# instance.
class OldInstanceTrait(HasTraits):
    value = Trait(otrait_test1)


class OldInstanceTest(AnyTraitTest):
    """ Assignment tests for the instance-defaulted ``OldInstanceTrait``.
    """

    _default_value = otrait_test1
    # Instances of the default's class and of its subclasses, plus None.
    _good_values = [
        otrait_test1,
        OTraitTest1(), OTraitTest2(), OTraitTest3(),
        None,
    ]
    # Numbers, the classes themselves, an unrelated instance, strings and
    # containers wrapping a valid instance.
    _bad_values = [
        0, 0.0, 0j,
        OTraitTest1, OTraitTest2,
        OBadTraitTest(),
        b"bytes", "string",
        [otrait_test1], (otrait_test1,), {"data": otrait_test1},
    ]

    def setUp(self):
        self.obj = OldInstanceTrait()


# New style class version:
# Fixture hierarchy rooted at a class that derives explicitly from
# ``object``.
class NTraitTest1(object):
    pass


class NTraitTest2(NTraitTest1):
    pass


class NTraitTest3(NTraitTest2):
    pass


# Unrelated class, used to build an instance that must be rejected.
class NBadTraitTest:
    pass


# Shared instance used as the default value of the trait below.
ntrait_test1 = NTraitTest1()


# Fixture: ``value`` is a ``Trait`` whose default is an ``NTraitTest1``
# instance.
class NewInstanceTrait(HasTraits):
    value = Trait(ntrait_test1)


class NewInstanceTest(AnyTraitTest):
    """ Assignment tests for the instance-defaulted ``NewInstanceTrait``.
    """

    _default_value = ntrait_test1
    # Instances of the default's class and of its subclasses, plus None.
    _good_values = [
        ntrait_test1,
        NTraitTest1(), NTraitTest2(), NTraitTest3(),
        None,
    ]
    # Numbers, the classes themselves, an unrelated instance, strings and
    # containers wrapping a valid instance.
    _bad_values = [
        0, 0.0, 0j,
        NTraitTest1, NTraitTest2,
        NBadTraitTest(),
        b"bytes", "string",
        [ntrait_test1], (ntrait_test1,), {"data": ntrait_test1},
    ]

    def setUp(self):
        self.obj = NewInstanceTrait()


# Fixtures for RegressionTest.test_factory_subclass_no_segfault.
class FactoryClass(HasTraits):
    pass


class ConsumerClass(HasTraits):
    x = Instance(FactoryClass, ())


# Overrides ``x`` with a plain FactoryClass instance rather than a trait
# declaration.
class ConsumerSubclass(ConsumerClass):
    x = FactoryClass()


# Fixtures for RegressionTest.test_trait_compound_instance: a compound trait
# embedding an Instance whose class is given by its dotted name.
embedded_instance_trait = Trait(
    "", Str, Instance("traits.has_traits.HasTraits")
)


# Uses the compound trait both directly and as the item trait of a List.
class Dummy(HasTraits):
    x = embedded_instance_trait
    xl = List(embedded_instance_trait)


class RegressionTest(unittest.TestCase):
    """ Check that fixed bugs stay fixed.
    """

    def test_factory_subclass_no_segfault(self):
        """ Test that we can provide an instance as a default in the definition
        of a subclass.
        """
        # Merely reading the attribute used to trigger a segfault, so the
        # result itself is not inspected.
        consumer = ConsumerSubclass()
        getattr(consumer, "x")

    def test_trait_compound_instance(self):
        """ Test that a deferred Instance() embedded in a TraitCompound handler
        and then a list will not replace the validate method for the outermost
        trait.
        """
        dummy = Dummy()
        # Assign an instance through the list first so that the Instance
        # trait resolves its class...
        dummy.xl = [HasTraits()]
        # ...and then check that the outer trait still accepts a string.
        dummy.x = "OK"


#  Trait(using a function) that must be an odd integer:


def odd_integer(object, name, value):
    """ Validator accepting only odd numeric values.

    Returns ``int(value)`` when ``value`` can be passed to ``float`` and
    leaves a remainder of 1 when divided by 2. Raises TraitError for every
    other value. The ``object`` and ``name`` arguments are not used.
    """
    try:
        # ``float`` serves only as an "is this numeric?" probe; its result
        # is discarded.
        float(value)
        if (value % 2) == 1:
            return int(value)
    except (TypeError, ValueError, ArithmeticError):
        # Conversion and arithmetic failures mean the value is not an odd
        # number. Catching only these (instead of a bare ``except``) lets
        # KeyboardInterrupt and SystemExit propagate.
        pass
    raise TraitError


# Fixture: ``value`` defaults to 99 and is validated by ``odd_integer``.
class OddIntegerTrait(HasTraits):
    value = Trait(99, odd_integer)


class OddIntegerTest(AnyTraitTest):
    """ Assignment tests for the function-validated ``OddIntegerTrait``.
    """

    _default_value = 99
    _good_values = [
        1, 3, 5, 7, 9, 999999999,
        1.0, 3.0, 5.0, 7.0, 9.0, 999999999.0,
        -1, -3, -5, -7, -9, -999999999,
        -1.0, -3.0, -5.0, -7.0, -9.0, -999999999.0,
    ]
    _bad_values = [
        0, 2, -2,
        1j,
        None,
        "1",
        [1], (1,), {1: 1},
    ]

    def setUp(self):
        self.obj = OddIntegerTrait()


class NotifierTraits(HasTraits):
    """ Fixture counting the change notifications seen for two Int traits.

    Each of ``value1`` and ``value2`` has a matching ``*_count`` trait that
    is incremented both by the catch-all handler and by the handler
    specific to that trait.
    """

    value1 = Int
    value2 = Int
    value1_count = Int
    value2_count = Int

    def _anytrait_changed(self, trait_name, old, new):
        # Changes to any other trait (including the counters) are ignored.
        if trait_name == "value1":
            self.value1_count = self.value1_count + 1
            return
        if trait_name == "value2":
            self.value2_count = self.value2_count + 1

    def _value1_changed(self, old, new):
        self.value1_count = self.value1_count + 1

    def _value2_changed(self, old, new):
        self.value2_count = self.value2_count + 1


# Checks how many times change handlers fire on a NotifierTraits object.
# Every handler, whether defined on the fixture class or registered here,
# bumps the same ``value1_count`` / ``value2_count`` counters, so each
# asserted number is the running total over all handlers active so far.
class NotifierTests(unittest.TestCase):
    def setUp(self):
        # Start every test from zeroed values and zeroed counters.
        obj = self.obj = NotifierTraits()
        obj.value1 = 0
        obj.value2 = 0
        obj.value1_count = 0
        obj.value2_count = 0

    def tearDown(self):
        # Unhook all handlers this test case may have registered.
        obj = self.obj
        obj.on_trait_change(self.on_value1_changed, "value1", remove=True)
        obj.on_trait_change(self.on_value2_changed, "value2", remove=True)
        obj.on_trait_change(self.on_anytrait_changed, remove=True)

    def on_anytrait_changed(self, object, trait_name, old, new):
        # Catch-all handler: counts only changes to value1 / value2.
        if trait_name == "value1":
            self.obj.value1_count += 1
        elif trait_name == "value2":
            self.obj.value2_count += 1

    def on_value1_changed(self):
        self.obj.value1_count += 1

    def on_value2_changed(self):
        self.obj.value2_count += 1

    def test_simple(self):
        obj = self.obj

        # Only the two handlers defined on NotifierTraits are active, so
        # each change adds 2 to the matching counter.
        obj.value1 = 1
        self.assertEqual(obj.value1_count, 2)
        self.assertEqual(obj.value2_count, 0)

        obj.value2 = 1
        self.assertEqual(obj.value1_count, 2)
        self.assertEqual(obj.value2_count, 2)

    def test_complex(self):
        obj = self.obj

        # With one extra per-trait handler registered, a change adds 3.
        obj.on_trait_change(self.on_value1_changed, "value1")
        obj.value1 = 1
        self.assertEqual(obj.value1_count, 3)
        self.assertEqual(obj.value2_count, 0)

        obj.on_trait_change(self.on_value2_changed, "value2")
        obj.value2 = 1
        self.assertEqual(obj.value1_count, 3)
        self.assertEqual(obj.value2_count, 3)

        # Adding the catch-all handler as well makes each change add 4.
        obj.on_trait_change(self.on_anytrait_changed)

        obj.value1 = 2
        self.assertEqual(obj.value1_count, 7)
        self.assertEqual(obj.value2_count, 3)

        # Re-assigning the same value leaves the counters untouched.
        obj.value1 = 2
        self.assertEqual(obj.value1_count, 7)
        self.assertEqual(obj.value2_count, 3)

        obj.value2 = 2
        self.assertEqual(obj.value1_count, 7)
        self.assertEqual(obj.value2_count, 7)

        # After removing the per-trait handlers a change adds 3 again.
        obj.on_trait_change(self.on_value1_changed, "value1", remove=True)
        obj.value1 = 3
        self.assertEqual(obj.value1_count, 10)
        self.assertEqual(obj.value2_count, 7)

        obj.on_trait_change(self.on_value2_changed, "value2", remove=True)
        obj.value2 = 3
        self.assertEqual(obj.value1_count, 10)
        self.assertEqual(obj.value2_count, 10)

        # With the catch-all handler removed too, only the two handlers
        # defined on the class remain, so a change adds 2.
        obj.on_trait_change(self.on_anytrait_changed, remove=True)

        obj.value1 = 4
        self.assertEqual(obj.value1_count, 12)
        self.assertEqual(obj.value2_count, 10)

        obj.value2 = 4
        self.assertEqual(obj.value1_count, 12)
        self.assertEqual(obj.value2_count, 12)


# Fixture whose change handler for ``x`` raises a RuntimeError constructed
# without arguments.
class RaisesArgumentlessRuntimeError(HasTraits):
    x = Int(0)

    def _x_changed(self):
        raise RuntimeError


class TestRuntimeError(unittest.TestCase):
    """ An exception raised in a change handler reaches the assigning code
    when a re-raising exception handler is installed.
    """

    def setUp(self):
        def ignore(*args):
            return None

        # Install a do-nothing handler that asks for exceptions to be
        # re-raised; it is popped again in tearDown.
        push_exception_handler(ignore, reraise_exceptions=True)

    def tearDown(self):
        pop_exception_handler()

    def test_runtime_error(self):
        raiser = RaisesArgumentlessRuntimeError()
        with self.assertRaises(RuntimeError):
            raiser.x = 5


# End of the delegation chain: holds the actual ``value`` (default 99.0).
class DelegatedFloatTrait(HasTraits):
    value = Trait(99.0)


# ``value`` is delegated to whatever object is stored in ``delegate``.
class DelegateTrait(HasTraits):
    value = Delegate("delegate")
    delegate = Trait(DelegatedFloatTrait())


# The two subclasses below each add one more link to the chain by making
# ``delegate`` default to an instance of the previous class.
class DelegateTrait2(DelegateTrait):
    delegate = Trait(DelegateTrait())


class DelegateTrait3(DelegateTrait):
    delegate = Trait(DelegateTrait2())


# Walks a four-object delegation chain (obj -> parent1 -> parent2 ->
# parent3), setting ``value`` from the far end towards ``obj`` and then
# deleting it again from ``obj`` outwards. The statement order matters:
# every assertion depends on all assignments and deletions before it.
class DelegateTests(unittest.TestCase):
    def test_delegation(self):
        obj = DelegateTrait3()

        # Initially the value comes from the end of the chain.
        self.assertEqual(obj.value, 99.0)
        parent1 = obj.delegate
        parent2 = parent1.delegate
        parent3 = parent2.delegate
        # A value set on an object is seen by everything delegating to it,
        # while the objects further down the chain keep their own value.
        parent3.value = 3.0
        self.assertEqual(obj.value, 3.0)
        parent2.value = 2.0
        self.assertEqual(obj.value, 2.0)
        self.assertEqual(parent3.value, 3.0)
        parent1.value = 1.0
        self.assertEqual(obj.value, 1.0)
        self.assertEqual(parent2.value, 2.0)
        self.assertEqual(parent3.value, 3.0)
        obj.value = 0.0
        self.assertEqual(obj.value, 0.0)
        self.assertEqual(parent1.value, 1.0)
        self.assertEqual(parent2.value, 2.0)
        self.assertEqual(parent3.value, 3.0)
        # Deleting a value makes the object fall back to its delegate.
        del obj.value
        self.assertEqual(obj.value, 1.0)
        del parent1.value
        self.assertEqual(obj.value, 2.0)
        self.assertEqual(parent1.value, 2.0)
        del parent2.value
        self.assertEqual(obj.value, 3.0)
        self.assertEqual(parent1.value, 3.0)
        self.assertEqual(parent2.value, 3.0)
        del parent3.value
        # Uncommenting the following line allows
        # the last assertions to pass. However, this
        # may not be intended behavior, so keeping
        # the line commented.
        # del parent2.value
        self.assertEqual(obj.value, 99.0)
        self.assertEqual(parent1.value, 99.0)
        self.assertEqual(parent2.value, 99.0)
        self.assertEqual(parent3.value, 99.0)


#  Complex(i.e. 'composite') Traits tests:

# Make a TraitCompound handler that does not have a fast_validate so we can
# check for a particular regression. The attribute may not exist at all, in
# which case there is nothing to remove.
slow = Trait(1, Range(1, 3), Range(-3, -1))
try:
    del slow.handler.fast_validate
except AttributeError:
    pass


# Suppress DeprecationWarnings from TraitPrefixList and TraitPrefixMap. The
# filter only applies inside the ``with`` block, so the class is defined
# within it.
with warnings.catch_warnings():
    warnings.filterwarnings(action="ignore", category=DeprecationWarning)

    # Fixture with traits composed from several alternatives.
    class complex_value(HasTraits):
        # Two integer ranges.
        num1 = Trait(1, Range(1, 5), Range(-5, -1))
        # An integer range or a prefix-matched name.
        num2 = Trait(
            1,
            Range(1, 5),
            TraitPrefixList("one", "two", "three", "four", "five"),
        )
        # An integer range or a prefix-matched name mapped to an integer.
        num3 = Trait(
            1,
            Range(1, 5),
            TraitPrefixMap(
                {"one": 1, "two": 2, "three": 3, "four": 4, "five": 5}
            ),
        )
        # The constant 10 combined with a nested compound trait (a Tuple or
        # the module-level ``slow`` trait), in both argument orders.
        num4 = Trait(1, Trait(1, Tuple, slow), 10)
        num5 = Trait(1, 10, Trait(1, Tuple, slow))


class test_complex_value(test_base2):
    """ Value checks for the compound traits of ``complex_value``.
    """

    def setUp(self):
        self.obj = complex_value()

    def test_num1(self):
        legal = [1, 2, 3, 4, 5, -1, -2, -3, -4, -5]
        illegal = [
            0, 6, -6,
            "0", "6", "-6",
            0.0, 6.0, -6.0,
            [1], (1,), {1: 1},
            None,
        ]
        # Legal values are expected to be read back unchanged.
        self.check_values("num1", 1, legal, illegal, list(legal))

    def test_enum_exceptions(self):
        """ Check that enumerated values can be combined with nested
        TraitCompound handlers.
        """
        legal = [1, 2, 3, -3, -2, -1, 10, ()]
        illegal = [0, 4, 5, -5, -4, 11]
        # Both argument orders of the nested trait must behave the same.
        for trait_name in ("num4", "num5"):
            self.check_values(trait_name, 1, legal, illegal)


# Tests for list-valued traits: length and item validation of the
# deprecated ``TraitList`` handler, and the events fired by ``List``.
# The checks mutate one list step by step, so statement order matters.
class test_list_value(test_base2):
    def setUp(self):
        # Defining the class is expected to emit a DeprecationWarning, so
        # the definition itself is part of what is being checked.
        with self.assertWarns(DeprecationWarning):

            class list_value(HasTraits):
                # Trait definitions:
                list1 = Trait([2], TraitList(Trait([1, 2, 3, 4]), maxlen=4))
                list2 = Trait(
                    [2], TraitList(Trait([1, 2, 3, 4]), minlen=1, maxlen=4)
                )
                alist = List()

        self.obj = list_value()
        # Most recent event recorded by ``_record_trait_list_event``.
        self.last_event = None

    def tearDown(self):
        del self.last_event

    def del_range(self, list, index1, index2):
        # Callable form of ``del list[index1:index2]`` for assertRaises.
        del list[index1:index2]

    def del_extended_slice(self, list, index1, index2, step):
        # Callable form of ``del list[index1:index2:step]`` for assertRaises.
        del list[index1:index2:step]

    def check_list(self, list):
        # Shared scenario for ``list1`` and ``list2``: the list starts as
        # [2], items come from [1, 2, 3, 4] and the length is capped at 4.
        self.assertEqual(list, [2])
        self.assertEqual(len(list), 1)
        list.append(3)
        self.assertEqual(len(list), 2)
        list[1] = 2
        self.assertEqual(list[1], 2)
        self.assertEqual(len(list), 2)
        list[0] = 1
        self.assertEqual(list[0], 1)
        self.assertEqual(len(list), 2)
        # 5 is not an allowed item, and three more items would exceed the
        # maximum length.
        self.assertRaises(TraitError, self.indexed_assign, list, 0, 5)
        self.assertRaises(TraitError, list.append, 5)
        self.assertRaises(TraitError, list.extend, [1, 2, 3])
        list.extend([3, 4])
        self.assertEqual(list, [1, 2, 3, 4])
        # The list is now full.
        self.assertRaises(TraitError, list.append, 1)
        # Three values are offered for an extended slice selecting two
        # positions.
        self.assertRaises(
            ValueError, self.extended_slice_assign, list, 0, 4, 2, [4, 5, 6]
        )
        del list[1]
        self.assertEqual(list, [1, 3, 4])
        del list[0]
        self.assertEqual(list, [3, 4])
        list[:0] = [1, 2]
        self.assertEqual(list, [1, 2, 3, 4])
        # Inserting through an empty slice into a full list is rejected.
        self.assertRaises(
            TraitError, self.indexed_range_assign, list, 0, 0, [1]
        )
        del list[0:3]
        self.assertEqual(list, [4])
        # Inserted items are validated too: 5 is not allowed.
        self.assertRaises(
            TraitError, self.indexed_range_assign, list, 0, 0, [4, 5]
        )

    def test_list1(self):
        self.check_list(self.obj.list1)

    def test_list2(self):
        self.check_list(self.obj.list2)
        # check_list leaves a single item behind, and list2 was declared
        # with minlen=1, so deletions are rejected.
        # NOTE(review): presumably both deletions below would empty the
        # list - confirm.
        self.assertRaises(TraitError, self.del_range, self.obj.list2, 0, 1)
        self.assertRaises(
            TraitError, self.del_extended_slice, self.obj.list2, 4, -5, -1
        )

    def assertLastTraitListEventEqual(self, index, removed, added):
        # Compare the three fields of the most recently recorded event.
        self.assertEqual(self.last_event.index, index)
        self.assertEqual(self.last_event.removed, removed)
        self.assertEqual(self.last_event.added, added)

    def test_trait_list_event(self):
        """ Record TraitListEvent behavior.
        """
        self.obj.alist = [1, 2, 3, 4]
        self.obj.on_trait_change(self._record_trait_list_event, "alist_items")
        del self.obj.alist[0]
        self.assertLastTraitListEventEqual(0, [1], [])
        self.obj.alist.append(5)
        self.assertLastTraitListEventEqual(3, [], [5])
        self.obj.alist[0:2] = [6, 7]
        self.assertLastTraitListEventEqual(0, [2, 3], [6, 7])
        self.obj.alist[:2] = [4, 5]
        self.assertLastTraitListEventEqual(0, [6, 7], [4, 5])
        self.obj.alist[0:2:1] = [8, 9]
        self.assertLastTraitListEventEqual(0, [4, 5], [8, 9])
        self.obj.alist[0:2:1] = [8, 9]
        # If list values stay the same, a new TraitListEvent will be generated.
        self.assertLastTraitListEventEqual(0, [8, 9], [8, 9])
        old_event = self.last_event
        self.obj.alist[4:] = []
        # If no structural change, NO new TraitListEvent will be generated.
        self.assertIs(self.last_event, old_event)
        # Extended slices are reported with a slice object as the index.
        self.obj.alist[0:4:2] = [10, 11]
        self.assertLastTraitListEventEqual(
            slice(0, 3, 2), [8, 4], [10, 11]
        )
        del self.obj.alist[1:4:2]
        self.assertLastTraitListEventEqual(slice(1, 4, 2), [9, 5], [])
        # Replace the whole list, then continue with contiguous slices.
        self.obj.alist = [1, 2, 3, 4]
        del self.obj.alist[2:4]
        self.assertLastTraitListEventEqual(2, [3, 4], [])
        self.obj.alist[:0] = [5, 6, 7, 8]
        self.assertLastTraitListEventEqual(0, [], [5, 6, 7, 8])
        del self.obj.alist[:2]
        self.assertLastTraitListEventEqual(0, [5, 6], [])
        del self.obj.alist[0:2]
        self.assertLastTraitListEventEqual(0, [7, 8], [])
        del self.obj.alist[:]
        self.assertLastTraitListEventEqual(0, [1, 2], [])

    def _record_trait_list_event(self, object, name, old, new):
        # Handler for "alist_items": keep the event object that arrives as
        # the ``new`` argument.
        self.last_event = new


# Fixture with two ``This`` traits, one declared with ``allow_none=False``.
class ThisDummy(HasTraits):
    allows_none = This()
    disallows_none = This(allow_none=False)


class TestThis(unittest.TestCase):
    """ Behaviour of ``This`` traits with and without ``allow_none``.
    """

    def test_this_none(self):
        dummy = ThisDummy()

        # The default trait accepts both None and instances of its class.
        self.assertIsNone(dummy.allows_none)
        dummy.allows_none = None
        dummy.allows_none = ThisDummy()
        self.assertIsNotNone(dummy.allows_none)
        dummy.allows_none = None
        self.assertIsNone(dummy.allows_none)

        # Still starts out as None, unavoidably.
        self.assertIsNone(dummy.disallows_none)
        dummy.disallows_none = ThisDummy()
        self.assertIsNotNone(dummy.disallows_none)
        # Assigning None is rejected and leaves the previous value intact.
        self.assertRaises(TraitError, setattr, dummy, "disallows_none", None)
        self.assertIsNotNone(dummy.disallows_none)

    def test_this_other_class(self):
        dummy = ThisDummy()
        # An object of an unrelated class is rejected and the trait keeps
        # its initial None.
        self.assertRaises(TraitError, setattr, dummy, "allows_none", object())
        self.assertIsNone(dummy.allows_none)


# Checks when a change notification fires for each ``comparison_mode`` and
# for the deprecated ``rich_compare`` metadata. Every test assigns, in
# order: a list, the same list object again, an equal but distinct list,
# and a different list, counting the notifications after each step.
class ComparisonModeTests(unittest.TestCase):
    def test_comparison_mode_none(self):
        class HasComparisonMode(HasTraits):
            bar = Trait(comparison_mode=ComparisonMode.none)

        old_compare = HasComparisonMode()
        events = []
        old_compare.on_trait_change(lambda: events.append(None), "bar")

        some_list = [1, 2, 3]

        # Every assignment notifies, even of the very same object.
        self.assertEqual(len(events), 0)
        old_compare.bar = some_list
        self.assertEqual(len(events), 1)
        old_compare.bar = some_list
        self.assertEqual(len(events), 2)
        old_compare.bar = [1, 2, 3]
        self.assertEqual(len(events), 3)
        old_compare.bar = [4, 5, 6]
        self.assertEqual(len(events), 4)

    def test_comparison_mode_identity(self):
        class HasComparisonMode(HasTraits):
            bar = Trait(comparison_mode=ComparisonMode.identity)

        old_compare = HasComparisonMode()
        events = []
        old_compare.on_trait_change(lambda: events.append(None), "bar")

        some_list = [1, 2, 3]

        # Re-assigning the same object is silent; an equal but distinct
        # list still notifies.
        self.assertEqual(len(events), 0)
        old_compare.bar = some_list
        self.assertEqual(len(events), 1)
        old_compare.bar = some_list
        self.assertEqual(len(events), 1)
        old_compare.bar = [1, 2, 3]
        self.assertEqual(len(events), 2)
        old_compare.bar = [4, 5, 6]
        self.assertEqual(len(events), 3)

    def test_comparison_mode_equality(self):
        class HasComparisonMode(HasTraits):
            bar = Trait(comparison_mode=ComparisonMode.equality)

        old_compare = HasComparisonMode()
        events = []
        old_compare.on_trait_change(lambda: events.append(None), "bar")

        some_list = [1, 2, 3]

        # Neither the same object nor an equal list notifies.
        self.assertEqual(len(events), 0)
        old_compare.bar = some_list
        self.assertEqual(len(events), 1)
        old_compare.bar = some_list
        self.assertEqual(len(events), 1)
        old_compare.bar = [1, 2, 3]
        self.assertEqual(len(events), 1)
        old_compare.bar = [4, 5, 6]
        self.assertEqual(len(events), 2)

    def test_rich_compare_false(self):
        # Record every DeprecationWarning raised while the class is defined.
        with warnings.catch_warnings(record=True) as warn_msgs:
            warnings.simplefilter("always", DeprecationWarning)

            class OldRichCompare(HasTraits):
                bar = Trait(rich_compare=False)

        # Check for a DeprecationWarning.
        self.assertEqual(len(warn_msgs), 1)
        warn_msg = warn_msgs[0]
        self.assertIs(warn_msg.category, DeprecationWarning)
        self.assertIn(
            "'rich_compare' metadata has been deprecated",
            str(warn_msg.message)
        )
        # The warning must be attributed to this module's file.
        _, _, this_module = __name__.rpartition(".")
        self.assertIn(this_module, warn_msg.filename)

        # Behaviour matches comparison_mode=ComparisonMode.identity.
        old_compare = OldRichCompare()
        events = []
        old_compare.on_trait_change(lambda: events.append(None), "bar")

        some_list = [1, 2, 3]

        self.assertEqual(len(events), 0)
        old_compare.bar = some_list
        self.assertEqual(len(events), 1)
        old_compare.bar = some_list
        self.assertEqual(len(events), 1)
        old_compare.bar = [1, 2, 3]
        self.assertEqual(len(events), 2)
        old_compare.bar = [4, 5, 6]
        self.assertEqual(len(events), 3)

    def test_rich_compare_true(self):
        # Record every DeprecationWarning raised while the class is defined.
        with warnings.catch_warnings(record=True) as warn_msgs:
            warnings.simplefilter("always", DeprecationWarning)

            class OldRichCompare(HasTraits):
                bar = Trait(rich_compare=True)

        # Check for a DeprecationWarning.
        self.assertEqual(len(warn_msgs), 1)
        warn_msg = warn_msgs[0]
        self.assertIs(warn_msg.category, DeprecationWarning)
        self.assertIn(
            "'rich_compare' metadata has been deprecated",
            str(warn_msg.message)
        )
        # The warning must be attributed to this module's file.
        _, _, this_module = __name__.rpartition(".")
        self.assertIn(this_module, warn_msg.filename)

        # Behaviour matches comparison_mode=ComparisonMode.equality: an
        # equal but distinct list does not notify.
        old_compare = OldRichCompare()
        events = []
        old_compare.on_trait_change(lambda: events.append(None), "bar")

        some_list = [1, 2, 3]

        self.assertEqual(len(events), 0)
        old_compare.bar = some_list
        self.assertEqual(len(events), 1)
        old_compare.bar = some_list
        self.assertEqual(len(events), 1)
        old_compare.bar = [1, 2, 3]
        self.assertEqual(len(events), 1)
        old_compare.bar = [4, 5, 6]
        self.assertEqual(len(events), 2)


@requires_traitsui
class TestDeprecatedTraits(unittest.TestCase):
    """ Instantiating ``Color``, ``RGBColor`` or ``Font`` from ``traits``
    emits a DeprecationWarning naming the trait.
    """

    def test_color_deprecated(self):
        expected = "'Color' in 'traits'"
        with self.assertWarnsRegex(DeprecationWarning, expected):
            Color()

    def test_rgb_color_deprecated(self):
        expected = "'RGBColor' in 'traits'"
        with self.assertWarnsRegex(DeprecationWarning, expected):
            RGBColor()

    def test_font_deprecated(self):
        expected = "'Font' in 'traits'"
        with self.assertWarnsRegex(DeprecationWarning, expected):
            Font()
