Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions crates/ci-core/src/ci_tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,34 @@ use crate::registry::Registry;

pub fn register_all_tests(registry: &mut Registry) {
registry
.add_to_registry("chi_square", ChiSquared {})
.add_to_registry("chi_square", || Box::new(ChiSquared {}))
.expect("Failed to register Chi Square test!");

registry
.add_to_registry("g_test", GTest {})
.add_to_registry("g_test", || Box::new(GTest {}))
.expect("Failed to register GTest!");

registry
.add_to_registry("independence_match", IndependenceMatch {})
.add_to_registry("independence_match", || Box::new(IndependenceMatch {}))
.expect("Failed to register Independence Match test!");

registry
.add_to_registry("likelihood_ratio", LikelihoodRatio {})
.add_to_registry("likelihood_ratio", || Box::new(LikelihoodRatio {}))
.expect("Failed to register Likelihood Ratio test!");

registry
.add_to_registry("modified_likelihood", ModifiedLikelihood {})
.add_to_registry("modified_likelihood", || Box::new(ModifiedLikelihood {}))
.expect("Failed to register Modified Likelihood tTest!");

registry
.add_to_registry("pearson_correlation", PearsonCorrelation {})
.add_to_registry("pearson_correlation", || Box::new(PearsonCorrelation {}))
.expect("Failed to register Pearson Correlation test!");

registry
.add_to_registry("pearson_equivalence", PearsonEquivalence {})
.add_to_registry("pearson_equivalence", || Box::new(PearsonEquivalence {}))
.expect("Failed to register Pearson Equivalence test!");

registry
.add_to_registry("power_divergence", PowerDivergence {})
.add_to_registry("power_divergence", || Box::new(PowerDivergence {}))
.expect("Failed to register Power Divergence test!");
}
20 changes: 10 additions & 10 deletions crates/ci-core/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::collections::HashMap;
/// The registry maintains a collection of test implementations that can be retrieved
/// by name.
pub struct Registry {
pub tests: HashMap<String, Box<dyn CITest>>,
pub tests: HashMap<String, fn() -> Box<dyn CITest>>,
}

impl Registry {
Expand All @@ -32,12 +32,13 @@ impl Registry {
///
/// # Errors
/// Returns an error if the test name is not found in the registry.
pub fn get_test(&self, test_name: &str) -> anyhow::Result<&dyn CITest> {
pub fn get_test(&self, test_name: &str) -> anyhow::Result<Box<dyn CITest>> {
let test_name = test_name.to_lowercase();
self.tests
.get(&test_name)
.map(std::convert::AsRef::as_ref)
.ok_or_else(|| anyhow::anyhow!("Test '{test_name}' not found!"))
let test = self.tests.get(&test_name);
match test {
Some(t) => Ok(t()),
None => Err(anyhow::anyhow!("Test '{test_name}' not found!")),
}
}

/// Returns a list of all registered test names.
Expand Down Expand Up @@ -71,14 +72,13 @@ impl Registry {
pub fn add_to_registry(
&mut self,
test_name: &str,
test: impl CITest + 'static,
test: fn() -> Box<dyn CITest>,
) -> anyhow::Result<()> {
let test_name = test_name.to_lowercase();
if self.tests.contains_key(&test_name) {
anyhow::bail!("Test already exists in registry!");
}
let ci_test = Box::new(test);
self.tests.insert(test_name, ci_test);
};
self.tests.insert(test_name, test);
Ok(())
}
}
Expand Down
1 change: 1 addition & 0 deletions crates/ci-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ ndarray = { workspace = true }
workspace = true

[lib]
name = "_ci_python"
crate-type = ["cdylib"]
7 changes: 7 additions & 0 deletions crates/ci-python/ci_python/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""Python bindings for conditional independence testing."""

from ._ci_python import Registry

__all__ = [
"Registry",
]
18 changes: 18 additions & 0 deletions crates/ci-python/ci_python/_ci_python.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from collections.abc import Callable

import numpy as np

class Registry:
def __init__(self) -> None: ...
def get_test(
self, test_name: str
) -> Callable[
[
np.ndarray[tuple[int], np.dtype[np.float64]],
np.ndarray[tuple[int], np.dtype[np.float64]],
np.ndarray[tuple[int, int], np.dtype[np.float64]],
bool,
],
bool | tuple[float, float] | tuple[float, float, float],
]: ...
def list_all(self) -> list[str]: ...
60 changes: 60 additions & 0 deletions crates/ci-python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
[project]
name = "ci-python"
version = "0.1.0"
requires-python = ">=3.10,<3.15"
dependencies = []
classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
"Operating System :: Unix",
"Operating System :: Microsoft :: Windows",
"Operating System :: MacOS",
]

[build-system]
requires = ["maturin>=1.11,<2.0"]
build-backend = "maturin"

[tool.maturin]
module-name = "ci_python._ci_python"

[project.optional-dependencies]
test = [
"pytest",
]

[tool.ruff]
line-length = 120

[tool.ruff.lint]
select = ["ALL"]
ignore = [
"ANN002", # missing-type-args: Checks that function *args arguments have type annotations.
"ANN003", # missing-type-kwargs: Checks that function **kwargs arguments have type annotations.
"COM812", # missing-trailing-comma: Checks for the absence of trailing commas.
"E741", # ambiguous-variable-name: Checks for the use of the characters 'l', 'O', or 'I' as variable names.
"EM101", # raw-string-in-exception: Checks for the use of string literals in exception constructors.
"EM102", # f-string-in-exception: Checks for the use of f-strings in exception constructors.
"F403", # undefined-local-with-import-star: Checks for the use of wildcard imports.
"FBT001", # boolean-type-hint-positional-argument: Checks for the use of boolean positional arguments in function definitions, as determined by the presence of a type hint containing bool as an evident subtype - e.g. bool, bool | int, typing.Optional[bool], etc.
"FBT002", # boolean-default-value-positional-argument: Checks for the use of boolean positional arguments in function definitions, as determined by the presence of a boolean default value.
"FBT003", # boolean-positional-value-in-call: Checks for boolean positional arguments in function calls.
"N812", # lowercase-imported-as-non-lowercase: Checks for lowercase imports that are aliased to non-lowercase names.
"PLR0913", # too-many-arguments: Checks for function definitions that include too many arguments.
"PLR1714", # repeated-equality-comparison: Checks for repeated equality comparisons that can be rewritten using the in operator.
"PLR2004", # magic-value-comparison: Checks for the use of unnamed numerical constants ("magic") values in comparisons.
"S101", # assert: Checks for uses of the assert keyword.
"SIM108", # if-else-block-instead-of-if-exp: Check for if-else-blocks that can be replaced with a ternary or binary operator.
"TD002", # missing-todo-author: Checks that a TODO comment includes an author.
"TD003", # missing-todo-link: Checks that a TODO comment is associated with a link to a relevant issue or ticket.
"TRY003", # raise-vanilla-args: Checks for long exception messages that are not defined in the exception class itself.
]

[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.ruff.format]
quote-style = "preserve"
81 changes: 30 additions & 51 deletions crates/ci-python/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,83 +1,62 @@
use ci_core::registry::Registry;
use ci_core::strategy::TestResult;
use ndarray::{Array1, Array2};
use numpy::{PyReadonlyArray1, PyReadonlyArray2};
use pyo3::prelude::*;
use std::sync::Arc;
use ci_core::{registry::Registry, strategy::CITest, strategy::TestResult};
use numpy::{PyReadonlyArray1, PyReadonlyArray2};

#[pyclass(frozen)]
pub struct PyRegistry(Arc<Registry>);
#[pyclass(frozen, name = "Registry")]
pub struct PyRegistry(Registry);

#[pymethods]
impl PyRegistry {
#[new]
#[must_use]
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
Self(Arc::new(Registry::new()))
Self(Registry::new())
}

fn get_test(&self, test_name: &str) -> PyResult<PyCITest> {
Ok(PyCITest {
test: self
.0
.get_test(test_name)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?,
})
}

fn list_all_tests(&self) -> PyResult<Vec<String>> {
fn list_all(&self) -> PyResult<Vec<String>> {
let tests = self
.0
.list_all_tests()
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
Ok(tests.into_iter().cloned().collect())
}

fn get_test(&self, test_name: &str) -> PyResult<PyCITest> {
self.0
.get_test(test_name)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
Ok(PyCITest {
registry: self.0.clone(),
test_name: test_name.to_string(),
})
}
}

#[pyclass(frozen)]
pub struct PyCITest {
registry: Arc<Registry>,
test_name: String,
#[pyclass(frozen, name = "CITest")]
struct PyCITest {
test: Box<dyn CITest>,
}

#[pymethods]
impl PyCITest {
/// Run the conditional independence test on the given data.
///
/// # Errors
///
/// Returns `PyRuntimeError` if the test lookup fails or the test itself returns an error.
#[allow(clippy::needless_pass_by_value)]
#[pyo3(signature = (array, x, y, boolean=true))]
pub fn __call__(
&self,
py: Python<'_>,
array: PyReadonlyArray2<'_, f64>,
x: PyReadonlyArray1<'_, f64>,
y: PyReadonlyArray1<'_, f64>,
py: Python,
array: PyReadonlyArray2<f64>,
x_value: PyReadonlyArray1<f64>,
y_value: PyReadonlyArray1<f64>,
boolean: bool,
) -> PyResult<Py<PyAny>> {
let array: Array2<f64> = array.as_array().to_owned();
let x: Array1<f64> = x.as_array().to_owned();
let y: Array1<f64> = y.as_array().to_owned();

let test = self
.registry
.get_test(&self.test_name)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
let arr = array.as_array().to_owned();
let x = x_value.as_array().to_owned();
let y = y_value.as_array().to_owned();

let result = test
.run_test(array, x, y, boolean)
let result = self
.test
.run_test(arr, x, y, boolean)
.map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;

match result {
TestResult::Correlated(Ok(t)) => Ok(t.into_pyobject(py)?.into_any().unbind()),
TestResult::Boolean(Ok(b)) => Ok(b.into_pyobject(py)?.to_owned().into_any().unbind()),
TestResult::Correlated(Ok((p_value, coefficient))) => Ok((p_value, coefficient)
.into_pyobject(py)?
.into_any()
.unbind()),
TestResult::Boolean(Err(e)) | TestResult::Correlated(Err(e)) => {
Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
e.to_string(),
Expand All @@ -88,7 +67,7 @@ impl PyCITest {
}

#[pymodule]
fn ci_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
fn _ci_python(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_class::<PyRegistry>()?;
m.add_class::<PyCITest>()?;
Ok(())
Expand Down
1 change: 1 addition & 0 deletions crates/ci-python/test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for ci_python."""
8 changes: 8 additions & 0 deletions crates/ci-python/test/test_Registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
from ci_python import Registry


class TestRegistry:
"""Tests for the :class:`Registry`."""

# TODO: https://github.com/GiPHouse/Conditional-Independence-Testing/issues/142
4 changes: 2 additions & 2 deletions crates/ci-python/tests/test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from pgmpy.estimators.CITests import pearsonr
import numpy as np
from ci_python import PyRegistry
from ci_python import Registry
import time
import pandas as pd


registry = PyRegistry()
registry = Registry()
test = registry.get_test("pearson_correlation")

N_ITER = 50
Expand Down