diff --git a/crates/ci-core/src/ci_tests/mod.rs b/crates/ci-core/src/ci_tests/mod.rs index 0576244..154f5a1 100644 --- a/crates/ci-core/src/ci_tests/mod.rs +++ b/crates/ci-core/src/ci_tests/mod.rs @@ -20,34 +20,34 @@ use crate::registry::Registry; pub fn register_all_tests(registry: &mut Registry) { registry - .add_to_registry("chi_square", ChiSquared {}) + .add_to_registry("chi_square", || Box::new(ChiSquared {})) .expect("Failed to register Chi Square test!"); registry - .add_to_registry("g_test", GTest {}) + .add_to_registry("g_test", || Box::new(GTest {})) .expect("Failed to register GTest!"); registry - .add_to_registry("independence_match", IndependenceMatch {}) + .add_to_registry("independence_match", || Box::new(IndependenceMatch {})) .expect("Failed to register Independence Match test!"); registry - .add_to_registry("likelihood_ratio", LikelihoodRatio {}) + .add_to_registry("likelihood_ratio", || Box::new(LikelihoodRatio {})) .expect("Failed to register Likelihood Ratio test!"); registry - .add_to_registry("modified_likelihood", ModifiedLikelihood {}) + .add_to_registry("modified_likelihood", || Box::new(ModifiedLikelihood {})) .expect("Failed to register Modified Likelihood tTest!"); registry - .add_to_registry("pearson_correlation", PearsonCorrelation {}) + .add_to_registry("pearson_correlation", || Box::new(PearsonCorrelation {})) .expect("Failed to register Pearson Correlation test!"); registry - .add_to_registry("pearson_equivalence", PearsonEquivalence {}) + .add_to_registry("pearson_equivalence", || Box::new(PearsonEquivalence {})) .expect("Failed to register Pearson Equivalence test!"); registry - .add_to_registry("power_divergence", PowerDivergence {}) + .add_to_registry("power_divergence", || Box::new(PowerDivergence {})) .expect("Failed to register Power Divergence test!"); } diff --git a/crates/ci-core/src/registry.rs b/crates/ci-core/src/registry.rs index 959e19d..3a2e40c 100644 --- a/crates/ci-core/src/registry.rs +++ b/crates/ci-core/src/registry.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; /// The registry maintains a collection of test implementations that can be retrieved /// by name. pub struct Registry { - pub tests: HashMap>, + pub tests: HashMap Box>, } impl Registry { @@ -32,12 +32,13 @@ impl Registry { /// /// # Errors /// Returns an error if the test name is not found in the registry. - pub fn get_test(&self, test_name: &str) -> anyhow::Result<&dyn CITest> { + pub fn get_test(&self, test_name: &str) -> anyhow::Result> { let test_name = test_name.to_lowercase(); - self.tests - .get(&test_name) - .map(std::convert::AsRef::as_ref) - .ok_or_else(|| anyhow::anyhow!("Test '{test_name}' not found!")) + let test = self.tests.get(&test_name); + match test { + Some(t) => Ok(t()), + None => Err(anyhow::anyhow!("Test '{test_name}' not found!")), + } } /// Returns a list of all registered test names. @@ -71,14 +72,13 @@ impl Registry { pub fn add_to_registry( &mut self, test_name: &str, - test: impl CITest + 'static, + test: fn() -> Box, ) -> anyhow::Result<()> { let test_name = test_name.to_lowercase(); if self.tests.contains_key(&test_name) { anyhow::bail!("Test already exists in registry!"); - } - let ci_test = Box::new(test); - self.tests.insert(test_name, ci_test); + }; + self.tests.insert(test_name, test); Ok(()) } } diff --git a/crates/ci-python/Cargo.toml b/crates/ci-python/Cargo.toml index 4b9ff1a..ebb0abe 100644 --- a/crates/ci-python/Cargo.toml +++ b/crates/ci-python/Cargo.toml @@ -17,4 +17,5 @@ ndarray = { workspace = true } workspace = true [lib] +name = "_ci_python" crate-type = ["cdylib"] diff --git a/crates/ci-python/ci_python/__init__.py b/crates/ci-python/ci_python/__init__.py new file mode 100644 index 0000000..6957bf8 --- /dev/null +++ b/crates/ci-python/ci_python/__init__.py @@ -0,0 +1,7 @@ +"""Python bindings for conditional independence testing.""" + +from ._ci_python import Registry + +__all__ = [ + "Registry", +] diff --git a/crates/ci-python/ci_python/_ci_python.pyi b/crates/ci-python/ci_python/_ci_python.pyi new file mode 100644 index 0000000..cd49745 --- /dev/null +++ b/crates/ci-python/ci_python/_ci_python.pyi @@ -0,0 +1,18 @@ +from collections.abc import Callable + +import numpy as np + +class Registry: + def __init__(self) -> None: ... + def get_test( + self, test_name: str + ) -> Callable[ + [ + np.ndarray[tuple[int], np.dtype[np.float64]], + np.ndarray[tuple[int], np.dtype[np.float64]], + np.ndarray[tuple[int, int], np.dtype[np.float64]], + bool, + ], + bool | tuple[float, float] | tuple[float, float, float], + ]: ... + def list_all(self) -> list[str]: ... diff --git a/crates/ci-python/pyproject.toml b/crates/ci-python/pyproject.toml new file mode 100644 index 0000000..f47b283 --- /dev/null +++ b/crates/ci-python/pyproject.toml @@ -0,0 +1,60 @@ +[project] +name = "ci-python" +version = "0.1.0" +requires-python = ">=3.10,<3.15" +dependencies = [] +classifiers = [ + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Operating System :: Unix", + "Operating System :: Microsoft :: Windows", + "Operating System :: MacOS", +] + +[build-system] +requires = ["maturin>=1.11,<2.0"] +build-backend = "maturin" + +[tool.maturin] +module-name = "ci_python._ci_python" + +[project.optional-dependencies] +test = [ + "pytest", +] + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +select = ["ALL"] +ignore = [ + "ANN002", # missing-type-args: Checks that function *args arguments have type annotations. + "ANN003", # missing-type-kwargs: Checks that function **kwargs arguments have type annotations. + "COM812", # missing-trailing-comma: Checks for the absence of trailing commas. + "E741", # ambiguous-variable-name: Checks for the use of the characters 'l', 'O', or 'I' as variable names. + "EM101", # raw-string-in-exception: Checks for the use of string literals in exception constructors. + "EM102", # f-string-in-exception: Checks for the use of f-strings in exception constructors. + "F403", # undefined-local-with-import-star: Checks for the use of wildcard imports. + "FBT001", # boolean-type-hint-positional-argument: Checks for the use of boolean positional arguments in function definitions, as determined by the presence of a type hint containing bool as an evident subtype - e.g. bool, bool | int, typing.Optional[bool], etc. + "FBT002", # boolean-default-value-positional-argument: Checks for the use of boolean positional arguments in function definitions, as determined by the presence of a boolean default value. + "FBT003", # boolean-positional-value-in-call: Checks for boolean positional arguments in function calls. + "N812", # lowercase-imported-as-non-lowercase: Checks for lowercase imports that are aliased to non-lowercase names. + "PLR0913", # too-many-arguments: Checks for function definitions that include too many arguments. + "PLR1714", # repeated-equality-comparison: Checks for repeated equality comparisons that can be rewritten using the in operator. + "PLR2004", # magic-value-comparison: Checks for the use of unnamed numerical constants ("magic") values in comparisons. + "S101", # assert: Checks for uses of the assert keyword. + "SIM108", # if-else-block-instead-of-if-exp: Check for if-else-blocks that can be replaced with a ternary or binary operator. + "TD002", # missing-todo-author: Checks that a TODO comment includes an author. + "TD003", # missing-todo-link: Checks that a TODO comment is associated with a link to a relevant issue or ticket. + "TRY003", # raise-vanilla-args: Checks for long exception messages that are not defined in the exception class itself. +] + +[tool.ruff.lint.pydocstyle] +convention = "google" + +[tool.ruff.format] +quote-style = "preserve" diff --git a/crates/ci-python/src/lib.rs b/crates/ci-python/src/lib.rs index dd53c00..909def4 100644 --- a/crates/ci-python/src/lib.rs +++ b/crates/ci-python/src/lib.rs @@ -1,83 +1,62 @@ -use ci_core::registry::Registry; -use ci_core::strategy::TestResult; -use ndarray::{Array1, Array2}; -use numpy::{PyReadonlyArray1, PyReadonlyArray2}; use pyo3::prelude::*; -use std::sync::Arc; +use ci_core::{registry::Registry, strategy::CITest, strategy::TestResult}; +use numpy::{PyReadonlyArray1, PyReadonlyArray2}; -#[pyclass(frozen)] -pub struct PyRegistry(Arc); +#[pyclass(frozen, name = "Registry")] +pub struct PyRegistry(Registry); #[pymethods] impl PyRegistry { #[new] - #[must_use] - #[allow(clippy::new_without_default)] pub fn new() -> Self { - Self(Arc::new(Registry::new())) + Self(Registry::new()) + } + + fn get_test(&self, test_name: &str) -> PyResult { + Ok(PyCITest { + test: self + .0 + .get_test(test_name) + .map_err(|e| PyErr::new::(e.to_string()))?, + }) } - fn list_all_tests(&self) -> PyResult> { + fn list_all(&self) -> PyResult> { let tests = self .0 .list_all_tests() .map_err(|e| PyErr::new::(e.to_string()))?; Ok(tests.into_iter().cloned().collect()) } - - fn get_test(&self, test_name: &str) -> PyResult { - self.0 - .get_test(test_name) - .map_err(|e| PyErr::new::(e.to_string()))?; - Ok(PyCITest { - registry: self.0.clone(), - test_name: test_name.to_string(), - }) - } } -#[pyclass(frozen)] -pub struct PyCITest { - registry: Arc, - test_name: String, +#[pyclass(frozen, name = "CITest")] +struct PyCITest { + test: Box, } #[pymethods] impl PyCITest { - /// Run the conditional independence test on the given data. - /// - /// # Errors - /// - /// Returns `PyRuntimeError` if the test lookup fails or the test itself returns an error. - #[allow(clippy::needless_pass_by_value)] - #[pyo3(signature = (array, x, y, boolean=true))] pub fn __call__( &self, - py: Python<'_>, - array: PyReadonlyArray2<'_, f64>, - x: PyReadonlyArray1<'_, f64>, - y: PyReadonlyArray1<'_, f64>, + py: Python, + array: PyReadonlyArray2, + x_value: PyReadonlyArray1, + y_value: PyReadonlyArray1, boolean: bool, ) -> PyResult> { - let array: Array2 = array.as_array().to_owned(); - let x: Array1 = x.as_array().to_owned(); - let y: Array1 = y.as_array().to_owned(); - - let test = self - .registry - .get_test(&self.test_name) - .map_err(|e| PyErr::new::(e.to_string()))?; + let arr = array.as_array().to_owned(); + let x = x_value.as_array().to_owned(); + let y = y_value.as_array().to_owned(); - let result = test - .run_test(array, x, y, boolean) + let result = self + .test + .run_test(arr, x, y, boolean) .map_err(|e| PyErr::new::(e.to_string()))?; match result { + TestResult::Correlated(Ok(t)) => Ok(t.into_pyobject(py)?.into_any().unbind()), TestResult::Boolean(Ok(b)) => Ok(b.into_pyobject(py)?.to_owned().into_any().unbind()), - TestResult::Correlated(Ok((p_value, coefficient))) => Ok((p_value, coefficient) - .into_pyobject(py)? - .into_any() - .unbind()), TestResult::Boolean(Err(e)) | TestResult::Correlated(Err(e)) => { Err(PyErr::new::( e.to_string(), @@ -88,7 +67,7 @@ impl PyCITest { } #[pymodule] -fn ci_python(m: &Bound<'_, PyModule>) -> PyResult<()> { +fn _ci_python(_py: Python, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; Ok(()) diff --git a/crates/ci-python/test/__init__.py b/crates/ci-python/test/__init__.py new file mode 100644 index 0000000..dab3f0c --- /dev/null +++ b/crates/ci-python/test/__init__.py @@ -0,0 +1 @@ +"""Unit tests for ci_python.""" diff --git a/crates/ci-python/test/test_Registry.py b/crates/ci-python/test/test_Registry.py new file mode 100644 index 0000000..dc9c8b6 --- /dev/null +++ b/crates/ci-python/test/test_Registry.py @@ -0,0 +1,8 @@ +import pytest +from ci_python import Registry + + +class TestRegistry: + """Tests for the :class:`Registry`.""" + + # TODO: https://github.com/GiPHouse/Conditional-Independence-Testing/issues/142 diff --git a/crates/ci-python/tests/test.py b/crates/ci-python/tests/test.py index ec5235b..845cef0 100644 --- a/crates/ci-python/tests/test.py +++ b/crates/ci-python/tests/test.py @@ -1,11 +1,11 @@ from pgmpy.estimators.CITests import pearsonr import numpy as np -from ci_python import PyRegistry +from ci_python import Registry import time import pandas as pd -registry = PyRegistry() +registry = Registry() test = registry.get_test("pearson_correlation") N_ITER = 50