Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/_pytask/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,29 @@
from typing import ParamSpec
from typing import Protocol
from typing import TypeVar
from typing import cast

from _pytask._hashlib import hash_value

P = ParamSpec("P")
R = TypeVar("R")

if TYPE_CHECKING:
from collections.abc import Callable
from typing import TypeAlias

from ty_extensions import Intersection

Memoized: TypeAlias = "Intersection[Callable[P, R], HasCache]"

P = ParamSpec("P")
R = TypeVar("R")


class HasCache(Protocol):
"""Protocol for objects that have a cache attribute."""

cache: Cache


if TYPE_CHECKING:
Memoized = Intersection[Callable[P, R], HasCache]


@dataclass
class CacheInfo:
hits: int = 0
Expand Down Expand Up @@ -68,9 +70,10 @@ def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:

return value

wrapped.cache = self # ty: ignore[unresolved-attribute]
wrapped_with_cache = cast("Memoized[P, R]", wrapped)
wrapped_with_cache.cache = self

return wrapped
return wrapped_with_cache

def add(self, key: str, value: Any) -> None:
self._cache[key] = value
Expand Down
7 changes: 6 additions & 1 deletion src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def pytask_collect_task(
)

markers = get_all_marks(obj)
attributes: dict[str, Any]

if isinstance(obj, TaskFunction):
attributes = {
Expand All @@ -361,7 +362,11 @@ def pytask_collect_task(
"is_generator": obj.pytask_meta.is_generator,
}
else:
attributes = {"collection_id": None, "after": [], "is_generator": False}
attributes = {
"collection_id": None,
"after": [],
"is_generator": False,
}

unwrapped = unwrap_task_function(obj)
if isinstance(unwrapped, Function):
Expand Down
2 changes: 1 addition & 1 deletion src/_pytask/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def pytask_configure(pm: PluginManager, raw_config: dict[str, Any]) -> dict[str,
"""Configure pytask."""
# Add all values by default so that many plugins do not need to copy over values.
config = {"pm": pm, "markers": {}} | raw_config
config["markers"] = parse_markers(config["markers"]) # type: ignore[arg-type]
config["markers"] = parse_markers(config["markers"])

pm.hook.pytask_parse_config(config=config)
pm.hook.pytask_post_parse(config=config)
Expand Down
10 changes: 7 additions & 3 deletions tests/test_collect_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

import pytest

Expand All @@ -16,6 +17,9 @@
from pytask import cli
from tests.conftest import enter_directory

if TYPE_CHECKING:
from _pytask.node_protocols import PTaskWithPath


def test_collect_task(runner, tmp_path):
source = """
Expand Down Expand Up @@ -315,7 +319,7 @@ def function(depends_on, produces): ...


def test_print_collected_tasks_without_nodes(capsys):
dictionary = {
dictionary: dict[Path, list[PTaskWithPath]] = {
Path("task_path.py"): [
Task(
base_name="function",
Expand All @@ -337,7 +341,7 @@ def test_print_collected_tasks_without_nodes(capsys):


def test_print_collected_tasks_with_nodes(capsys):
dictionary = {
dictionary: dict[Path, list[PTaskWithPath]] = {
Path("task_path.py"): [
Task(
base_name="function",
Expand All @@ -361,7 +365,7 @@ def test_print_collected_tasks_with_nodes(capsys):

@pytest.mark.parametrize(("show_nodes", "expected_add"), [(False, "src"), (True, "..")])
def test_find_common_ancestor_of_all_nodes(show_nodes, expected_add):
tasks = [
tasks: list[PTaskWithPath] = [
Task(
base_name="function",
path=Path.cwd() / "src" / "task_path.py",
Expand Down
40 changes: 20 additions & 20 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.