diff --git a/typen/_enforcer.py b/typen/_enforcer.py index 54bb039..9ec74b5 100644 --- a/typen/_enforcer.py +++ b/typen/_enforcer.py @@ -2,6 +2,7 @@ from traits.api import HasTraits, TraitError +from typen._typing import typing_to_trait from typen.exceptions import ( ParameterTypeError, ReturnTypeError, @@ -89,10 +90,18 @@ def __init__( raise UnspecifiedReturnTypeError(msg.format(func.__name__)) self.returns = UNSPECIFIED + # Convert any non-trait type hints to traits + spec = { + k: to_traitable(v) for k, v in spec.items() + } + self.returns = to_traitable(self.returns) + self.packed_args_spec = to_traitable(self.packed_args_spec) + self.packed_kwargs_spec = to_traitable(self.packed_kwargs_spec) + # Restore order of args self.args = [Arg(k, spec[k]) for k in params.keys()] - # Validate defaults + # Store defaults for validation self.default_kwargs = { k: v.default for k, v in params.items() if v.default is not inspect.Parameter.empty @@ -222,6 +231,18 @@ def verify_result(self, value): raise exception from None +def to_traitable(param_type): + """ + Function to attempt to turn the input parameter type to a suitable Traits + type. + """ + param_type = typing_to_trait(param_type) + # More conversion attempts can be made if more than just typing types + # are supported + + return param_type + + class FunctionSignature(HasTraits): pass diff --git a/typen/_typing.py b/typen/_typing.py new file mode 100644 index 0000000..2841cf4 --- /dev/null +++ b/typen/_typing.py @@ -0,0 +1,39 @@ +import typing + +import traits.api as traits_api + +from typen.exceptions import TypenError +from typen.traits import ValidatedList + + +def typing_to_trait(arg_type): + """ + Attempt to convert a ``typing`` type into an appropriate ``traits`` type + + Raises + ------ + TypenError + If the input type is a ``typing`` type but it could not be converted + to a traits type. This may be because the type is not currently + supported. + """ + + if not hasattr(arg_type, "__origin__"): + return arg_type + + origin = arg_type.__origin__ or arg_type + + if origin in [typing.List, list]: + if arg_type.__args__ is not None and arg_type.__args__[0] is not typing.Any: + contained = arg_type.__args__[0] + return ValidatedList(typing_to_trait(contained)) + else: + return traits_api.List() + elif origin in [typing.Tuple, tuple]: + if arg_type.__args__ is not None: + contained = [typing_to_trait(arg) for arg in arg_type.__args__] + return traits_api.Tuple(*contained) + else: + return traits_api.Tuple() + + raise TypenError("Could not convert {} to trait".format(arg_type)) diff --git a/typen/tests/test_typing.py b/typen/tests/test_typing.py new file mode 100644 index 0000000..4105eeb --- /dev/null +++ b/typen/tests/test_typing.py @@ -0,0 +1,141 @@ +import typing +import unittest + +import traits.api as traits_api + +from typen._enforcer import Enforcer +from typen._typing import typing_to_trait +from typen.exceptions import ParameterTypeError +from typen.traits import ValidatedList + + +class TypingToTrait(unittest.TestCase): + def test_typing_to_trait_list(self): + typ = typing.List + + traits_typ = typing_to_trait(typ) + + self.assertIsInstance(traits_typ, traits_api.List) + self.assertNotIsInstance(traits_typ, ValidatedList) + + traits_typ.validate(None, None, [1]) + traits_typ.validate(None, None, ["a"]) + with self.assertRaises(traits_api.TraitError): + traits_typ.validate(None, None, "a") + + def test_typing_to_trait_list_of_int(self): + typ = typing.List[int] + + traits_typ = typing_to_trait(typ) + + self.assertIsInstance(traits_typ, traits_api.List) + self.assertIsInstance(traits_typ, ValidatedList) + + traits_typ.validate(None, None, [1]) + + with self.assertRaises(traits_api.TraitError): + traits_typ.validate(None, None, ["a"]) + + def test_typing_to_trait_nested_list(self): + typ = typing.List[typing.List[str]] + + traits_typ = typing_to_trait(typ) + + self.assertIsInstance(traits_typ, traits_api.List) + self.assertIsInstance(traits_typ, ValidatedList) + + traits_typ.validate(None, None, [["a", "b"], ["c", "d"]]) + + with self.assertRaises(traits_api.TraitError): + traits_typ.validate(None, None, ["a"]) + + with self.assertRaises(traits_api.TraitError): + traits_typ.validate(None, None, [[1, "b"], ["c", "d"]]) + + def test_typing_to_trait_int(self): + typ = int + + traits_typ = typing_to_trait(typ) + self.assertIs(traits_typ, int) + + def test_typing_to_trait_tuple(self): + typ = typing.Tuple[int, str, int] + + traits_typ = typing_to_trait(typ) + + self.assertIsInstance(traits_typ, traits_api.Tuple) + + #self.fail("Complete test") + + +class EnforceTypingTypes(unittest.TestCase): + def test_enforce_typing_list(self): + def test_function(a: typing.List): + pass + e = Enforcer(test_function) + + e.verify_args([[1, 2]], {}) + e.verify_args([[1.1, 0.1]], {}) + e.verify_args([["a", "b"]], {}) + with self.assertRaises(ParameterTypeError): + e.verify_args([1], {}) + with self.assertRaises(ParameterTypeError): + e.verify_args([(1, 2)], {}) + + def test_enforce_typing_list_spec(self): + def test_function(a: typing.List[int]): + pass + e = Enforcer(test_function) + + e.verify_args([[1, 2]], {}) + with self.assertRaises(ParameterTypeError): + e.verify_args([[1.1, 0.1]], {}) + with self.assertRaises(ParameterTypeError): + e.verify_args([["a", "b"]], {}) + with self.assertRaises(ParameterTypeError): + e.verify_args([1], {}) + with self.assertRaises(ParameterTypeError): + e.verify_args([(1, 2)], {}) + + def test_enforce_typing_tuple(self): + def test_function(a: typing.Tuple): + pass + e = Enforcer(test_function) + + e.verify_args([(1, 2)], {}) + + e.verify_args([[1.1, 0.1]], {}) + + # Lists can be cast to tuple + e.verify_args([["a", "b"]], {}) + + with self.assertRaises(ParameterTypeError): + e.verify_args([1], {}) + + def test_enforce_typing_tuple_spec(self): + def test_function(a: typing.Tuple[int, int]): + pass + e = Enforcer(test_function) + + e.verify_args([(1, 2)], {}) + + with self.assertRaises(ParameterTypeError): + e.verify_args([(1.1, 0.1)], {}) + + with self.assertRaises(ParameterTypeError): + e.verify_args([(1, "b")], {}) + + with self.assertRaises(ParameterTypeError): + e.verify_args([("a", "b")], {}) + + with self.assertRaises(ParameterTypeError): + e.verify_args((1,), {}) + + with self.assertRaises(ParameterTypeError): + e.verify_args([(1, 2, 3)], {}) + + with self.assertRaises(ParameterTypeError): + e.verify_args([[1, 2]], {}) + + +#TODO: return types, args, kwargs diff --git a/typen/traits.py b/typen/traits.py new file mode 100644 index 0000000..2253776 --- /dev/null +++ b/typen/traits.py @@ -0,0 +1,13 @@ +from traits.api import List + + +class ValidatedList(List): + """ + Defines a list that does validation on the internal type + """ + + def validate(self, object, name, value): + value = super(ValidatedList, self).validate(object, name, value) + + for item in value: + self.item_trait.validate(object, name, item)