Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 131 additions & 38 deletions src/assertical/fake/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Optional,
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
Expand Down Expand Up @@ -55,6 +54,12 @@ class CollectionType(IntEnum):
OPTIONAL_LIST = auto() # For type T - represents list[Optional[T]]
REQUIRED_SET = auto() # For type T - represents set[T]
OPTIONAL_SET = auto() # For type T - represents set[Optional[T]]
REQUIRED_DICT = auto()
OPTIONAL_DICT = auto()


SUPPORTED_COLLECTION_TYPES = {list, dict, set}
TWO_PARAMETER_COLLECTION_TYPES: set[CollectionType] = {CollectionType.OPTIONAL_DICT, CollectionType.REQUIRED_DICT}


@dataclass
Expand Down Expand Up @@ -85,12 +90,22 @@ class PropertyGenerationDetails:
# For example, a list[int] would have type_to_generate as int and this property as REQUIRED_LIST
collection_type: Optional[CollectionType]

second_type_to_generate: Optional[type] = None
Comment thread
mikejturner marked this conversation as resolved.
second_is_primitive_type: Optional[bool] = None
second_is_optional: Optional[bool] = None


@dataclass
class _PlaceholderDataclassBase:
"""Dataclass has no base class - instead we fall back to using this as a placeholder"""


@dataclass
class _PlaceholderCollectionBase:
"""lists, dicts and sets have no base class other than object
- instead we fall back to using this as a placeholder"""


AnyType = TypeVar("AnyType")


Expand Down Expand Up @@ -248,6 +263,9 @@ def get_generatable_class_base(t: type) -> Optional[type]:
if optional_arg is not None:
target_type = optional_arg

if get_origin(target_type) in SUPPORTED_COLLECTION_TYPES:
return _PlaceholderCollectionBase

if not inspect.isclass(target_type):
return None

Expand Down Expand Up @@ -305,7 +323,7 @@ def enumerate_class_properties(t: type) -> Generator[PropertyGenerationDetails,
if t_generatable_base is None:
raise Exception(f"Type {t} does not inherit from one of {CLASS_INSTANCE_GENERATORS.keys()}")

type_hints = get_type_hints(t)
type_hints = TYPE_HINT_FETCHER[t_generatable_base](t)

for member_name in CLASS_MEMBER_FETCHERS[t_generatable_base](t):

Expand All @@ -320,8 +338,12 @@ def enumerate_class_properties(t: type) -> Generator[PropertyGenerationDetails,
collection_type: Optional[CollectionType] = None
is_optional: bool = False
is_primitive: bool = False
second_type_to_generate: Optional[type] = None
second_is_primitive: Optional[bool] = None
second_is_optional: Optional[bool] = None

if member_name in type_hints:
declared_type = cast(type, type_hints[member_name])
declared_type = type_hints[member_name]
member_type = remove_passthrough_type(declared_type)
optional_arg_type = get_optional_type_argument(member_type)
is_optional = optional_arg_type is not None
Expand All @@ -332,13 +354,37 @@ def enumerate_class_properties(t: type) -> Generator[PropertyGenerationDetails,
collection_type = CollectionType.OPTIONAL_LIST
elif get_origin(optional_arg_type) == set:
collection_type = CollectionType.OPTIONAL_SET
elif get_origin(optional_arg_type) == dict:
collection_type = CollectionType.OPTIONAL_DICT
else:
if get_origin(member_type) == list:
collection_type = CollectionType.REQUIRED_LIST
elif get_origin(member_type) == set:
collection_type = CollectionType.REQUIRED_SET
elif get_origin(member_type) == dict:
collection_type = CollectionType.REQUIRED_DICT

if collection_type is not None:
# Determine second argument (if required)
if collection_type in TWO_PARAMETER_COLLECTION_TYPES:
second_member_type = get_args(optional_arg_type)[1] if is_optional else get_args(member_type)[1]
second_optional_arg_type = get_optional_type_argument(second_member_type)
second_is_optional = second_optional_arg_type is not None
if collection_type in (CollectionType.OPTIONAL_DICT, CollectionType.REQUIRED_DICT):
Comment thread
mikejturner marked this conversation as resolved.
if is_generatable_type(second_member_type):
second_type_to_generate = get_first_generatable_primitive(
second_member_type, include_optional=False
)
assert (
second_type_to_generate is not None
), f"Error generating member {member_name}. Couldn't find type for {second_member_type}"
second_is_primitive = True
elif get_generatable_class_base(second_member_type) is not None:
second_type_to_generate = (
second_optional_arg_type if second_is_optional else second_member_type
)

# Determine first argument
member_type = get_args(optional_arg_type)[0] if is_optional else get_args(member_type)[0]
optional_arg_type = get_optional_type_argument(member_type)
is_optional = optional_arg_type is not None
Expand Down Expand Up @@ -373,6 +419,9 @@ def enumerate_class_properties(t: type) -> Generator[PropertyGenerationDetails,
is_primitive_type=is_primitive,
is_optional=is_optional,
collection_type=collection_type,
second_type_to_generate=second_type_to_generate,
second_is_primitive_type=second_is_primitive,
second_is_optional=second_is_optional,
)


Expand All @@ -381,9 +430,10 @@ def generate_class_instance( # noqa: C901
seed: int = 1,
optional_is_none: bool = False,
generate_relationships: bool = False,
_return_seed: bool = False,
_visited_type_stack: Optional[list[type]] = None,
**kwargs: Any,
) -> AnyType:
) -> Union[AnyType, tuple[AnyType, int]]:
"""Given a child class of a key to CLASS_INSTANCE_GENERATORS - generate an instance of that class
with all properties being assigned unique values based off of seed. The values will match type hints

Expand All @@ -407,7 +457,8 @@ def generate_class_instance( # noqa: C901
if _visited_type_stack is None:
_visited_type_stack = []
if t in _visited_type_stack:
return None # type: ignore # This only happens in recursion - the top level object will never be None
# This only happens in recursion - the top level object will never be None
return (None, seed) if _return_seed else None # type: ignore
_visited_type_stack.append(t)

# We can only generate class instances of classes that inherit from a known base
Expand All @@ -430,17 +481,46 @@ def generate_class_instance( # noqa: C901
continue

if member.type_to_generate is None:
raise Exception(
f"Type {t} has property {member.name} with type {member.declared_type} that cannot be generated"
)
# Don't raise exception for ungeneratable types if their value is going to be None
if not (optional_is_none and member.is_optional):
raise Exception(
f"Type {t} has property {member.name} with type {member.declared_type} that cannot be generated"
)

generated_value: Any = None
empty_collection: bool = False
collection_type: Optional[CollectionType] = member.collection_type

def generate_member(
is_primitive_type: bool, type_to_generate: type, current_seed: int, empty_collection: bool
) -> tuple[Any, int, bool]:
if is_primitive_type:
generated_value = generate_value(type_to_generate, seed=current_seed, optional_is_none=optional_is_none)
current_seed += 1
else:
generated_value = None
if generate_relationships:
generated_value, current_seed = generate_class_instance(
type_to_generate,
seed=current_seed,
optional_is_none=optional_is_none,
generate_relationships=generate_relationships,
_visited_type_stack=_visited_type_stack,
_return_seed=True,
)

# None can be generated when Type A has child B that includes a backreference to A. in these
# circumstances the visited_types short circuit will just return None from generate_class_instance
# (to stop infinite recursion) The way we handle this is to just generate an empty list (if this is
# a list entity)
if generated_value is None:
empty_collection = True

return generated_value, current_seed, empty_collection

if optional_is_none and (
member.collection_type == CollectionType.OPTIONAL_LIST
or member.collection_type == CollectionType.OPTIONAL_SET
member.collection_type
in [CollectionType.OPTIONAL_LIST, CollectionType.OPTIONAL_SET, CollectionType.OPTIONAL_DICT]
):
# We can short circuit some generation if we know the top level collection should be None
# In this case - we just set everything to None
Expand All @@ -452,38 +532,32 @@ def generate_class_instance( # noqa: C901
# that are None - so we just add a None to the parent collection (or just generate None)
generated_value = None
current_seed += 1
elif member.is_primitive_type:
generated_value = generate_value(
member.type_to_generate, seed=current_seed, optional_is_none=optional_is_none
)
current_seed += 1
else:
if generate_relationships:
generated_value = generate_class_instance(
member.type_to_generate,
seed=current_seed,
optional_is_none=optional_is_none,
generate_relationships=generate_relationships,
_visited_type_stack=_visited_type_stack,
)

# None can be generated when Type A has child B that includes a backreference to A. in these
# circumstances the visited_types short circuit will just return None from generate_class_instance
# (to stop infinite recursion) The way we handle this is to just generate an empty list (if this is
# a list entity)
if generated_value is None:
empty_collection = True
# collection_type = CollectionType.REQUIRED_LIST
else:
# In this case we have a complex type but we aren't generating relationships - throw in a placeholder
empty_collection = True
generated_value = None
current_seed += 1000 # Rather than calculating how many seed values were utilised - set it arbitrarily high
generated_value, current_seed, empty_collection = generate_member(
is_primitive_type=member.is_primitive_type,
type_to_generate=member.type_to_generate, # type: ignore
current_seed=current_seed,
empty_collection=empty_collection,
)

if collection_type == CollectionType.REQUIRED_LIST or collection_type == CollectionType.OPTIONAL_LIST:
values[member.name] = [] if empty_collection else [generated_value]
elif collection_type == CollectionType.REQUIRED_SET or collection_type == CollectionType.OPTIONAL_SET:
values[member.name] = set([]) if empty_collection else set([generated_value])
elif collection_type == CollectionType.REQUIRED_DICT or collection_type == CollectionType.OPTIONAL_DICT:
if optional_is_none and member.second_is_optional:
# In this case the parent collection is NOT able to be set to None but does support adding items
# that are None - so we just add a None to the parent collection (or just generate None)
second_generated_value = None
current_seed += 1
else:
second_generated_value, current_seed, empty_collection = generate_member(
is_primitive_type=member.second_is_primitive_type, # type: ignore
type_to_generate=member.second_type_to_generate, # type: ignore
current_seed=current_seed,
empty_collection=empty_collection,
)
values[member.name] = {} if empty_collection else {generated_value: second_generated_value}
else:
values[member.name] = generated_value

Expand All @@ -492,7 +566,9 @@ def generate_class_instance( # noqa: C901
raise Exception(f"The following kwargs were unused {expected_kwargs_references.difference(kwargs_references)}")

_visited_type_stack.pop() # When we finish generating a type, allow recursion back into that type
return CLASS_INSTANCE_GENERATORS[t_generatable_base](t, values)

instance = CLASS_INSTANCE_GENERATORS[t_generatable_base](t, values)
return (instance, current_seed) if _return_seed else instance


def clone_class_instance(obj: AnyType, ignored_properties: Optional[set[str]] = None) -> AnyType:
Expand Down Expand Up @@ -627,12 +703,18 @@ def register_value_generator(t: type, generator: Callable[[int], Any]) -> None:
BASE_CLASS_PUBLIC_MEMBERS: dict[type, set[str]] = {}
DEFAULT_CLASS_INSTANCE_GENERATOR: Callable[[type, dict[str, Any]], Any] = lambda target, kwargs: target(**kwargs)
DEFAULT_MEMBER_FETCHER: Callable[[type], list[str]] = lambda target: [name for (name, _) in inspect.getmembers(target)]
DEFAULT_PUBLIC_MEMBER_CHECKER: Callable[[str], bool] = is_member_public

TYPE_HINT_FETCHER: dict[type, Callable[[type], dict[str, type]]] = {}
DEFAULT_TYPE_HINT_FETCHER: Callable[[type], dict[str, type]] = get_type_hints


def register_base_type(
base_type: type,
instance_generator: Callable[[type, dict[str, Any]], Any],
member_fetcher: Callable[[type], list[str]],
public_member_checker: Callable[[str], bool] = DEFAULT_PUBLIC_MEMBER_CHECKER,
type_hint_fetcher: Callable[[type], dict[str, type]] = DEFAULT_TYPE_HINT_FETCHER,
) -> None:
"""Registers a type that will allow all subclasses to be generated/cloned by functions in this module.

Expand All @@ -646,7 +728,8 @@ def register_base_type(
polluting the global registry"""
CLASS_INSTANCE_GENERATORS[base_type] = instance_generator
CLASS_MEMBER_FETCHERS[base_type] = member_fetcher
BASE_CLASS_PUBLIC_MEMBERS[base_type] = set([m for m in member_fetcher(base_type) if is_member_public(m)])
BASE_CLASS_PUBLIC_MEMBERS[base_type] = set([m for m in member_fetcher(base_type) if public_member_checker(m)])
TYPE_HINT_FETCHER[base_type] = type_hint_fetcher


# Base type registration
Expand All @@ -656,6 +739,16 @@ def register_base_type(
lambda target: [f.name for f in fields(target) if f.init],
)

# Handling of collections
register_base_type(
_PlaceholderCollectionBase,
lambda target, kwargs: kwargs["self"],
lambda _: ["self"],
lambda _: False, # "base class" doesn't have any public members
lambda target: {"self": target},
)


if "pydantic_xml" in sys.modules:
register_base_type(
BaseXmlModel,
Expand Down
5 changes: 5 additions & 0 deletions tests/fake/test_generator_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def test_generate_value():
generate_value(RandomOtherClass, 1)
with pytest.raises(Exception):
generate_value(list[int], 1)
with pytest.raises(Exception):
generate_value(dict[str, int], 1)

assert generate_value(str, 1, True) == generate_value(str, 1, True)
assert generate_value(str, 1, True) is not generate_value(str, 1, True)
Expand Down Expand Up @@ -230,6 +232,7 @@ def test_is_passthrough_type():
assert not is_passthrough_type(Union[str, int])
assert not is_passthrough_type(str)
assert not is_passthrough_type(list[int])
assert not is_passthrough_type(dict[str, int])


def test_remove_passthrough_type():
Expand Down Expand Up @@ -299,6 +302,7 @@ def test_get_first_generatable_primitive():
assert get_first_generatable_primitive(list[str], include_optional=True) is None
assert get_first_generatable_primitive(list[int], include_optional=True) is None
assert get_first_generatable_primitive(Mapped[list[str]], include_optional=True) is None
assert get_first_generatable_primitive(dict[str, int], include_optional=True) is None

# With include_optional disabled
assert get_first_generatable_primitive(int, include_optional=False) == int
Expand All @@ -318,6 +322,7 @@ def test_get_first_generatable_primitive():
assert get_first_generatable_primitive(list[str], include_optional=False) is None
assert get_first_generatable_primitive(list[int], include_optional=False) is None
assert get_first_generatable_primitive(Mapped[list[str]], include_optional=False) is None
assert get_first_generatable_primitive(dict[str, int], include_optional=False) is None


def test_get_first_generatable_primitive_py310_optional():
Expand Down
Loading
Loading