Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions .github/workflows/compat.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13"]
integration: ["onnx", "torch", "timm", "tensorflow"]
integration: ["onnx", "torch", "timm", "tensorflow", "gguf", "openai", "transformers", "ultralytics"]
version-set: ["min", "current"]
exclude:
- python-version: "3.13"
Expand All @@ -44,7 +44,7 @@ jobs:
- name: Install project
run: |
python -m pip install --upgrade pip
pip install -e . pytest
pip install -e . pytest pytest-asyncio pytest-mock
- name: Install compatibility dependencies
run: |
python scripts/compat_matrix.py deps \
Expand All @@ -55,6 +55,34 @@ jobs:
- name: Run compatibility test
run: python -m pytest "tests/compat/test_${{ matrix.integration }}_compat.py" -q

compat-mlx:
if: github.event_name == 'schedule' || github.event_name == 'workflow_dispatch'
runs-on: macos-14
strategy:
fail-fast: false
matrix:
python-version: ["3.12", "3.13"]
version-set: ["current"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install project
run: |
python -m pip install --upgrade pip
pip install -e . pytest pytest-asyncio pytest-mock
- name: Install compatibility dependencies
run: |
python scripts/compat_matrix.py deps \
--integration "mlx" \
--version-set "${{ matrix.version-set }}" \
--python-version "${{ matrix.python-version }}" > compat-requirements.txt
pip install -r compat-requirements.txt
- name: Run compatibility test
run: python -m pytest tests/compat/test_mlx_compat.py -q

compat-canary-314:
if: github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'compat')
runs-on: ubuntu-latest
Expand All @@ -75,7 +103,7 @@ jobs:
- name: Install project
run: |
python -m pip install --upgrade pip
pip install -e . pytest
pip install -e . pytest pytest-asyncio pytest-mock
- name: Install compatibility dependencies
run: |
python scripts/compat_matrix.py deps \
Expand Down
44 changes: 44 additions & 0 deletions scripts/compat_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import annotations

import argparse
import json
import sys

MATRIX = {
Expand Down Expand Up @@ -33,13 +34,37 @@
"min": ["tensorflow==2.16.1", "keras==3.3.3", "numpy==1.26.4"],
"current": ["tensorflow==2.18.0", "keras==3.8.0", "numpy==2.0.2"],
},
"gguf": {
"min": ["llama-cpp-python==0.2.90", "numpy==1.26.4"],
"current": ["llama-cpp-python==0.3.4", "numpy==2.1.3"],
},
"openai": {
"min": ["openai==1.30.0"],
"current": ["openai==1.61.0"],
},
"transformers": {
"min": ["transformers==4.40.0", "torch==2.4.1", "numpy==1.26.4"],
"current": ["transformers==4.47.0", "torch==2.6.0", "numpy==2.1.3"],
},
"ultralytics": {
"min": ["ultralytics==8.0.0", "numpy==1.26.4"],
"current": ["ultralytics==8.3.4", "numpy==2.1.3"],
},
"mlx": {
"current": ["mlx==0.22.0", "mlx-lm==0.21.0"],
},
}

SUPPORTED_PYTHON = {
"onnx": ["3.10", "3.11", "3.12", "3.13", "3.14"],
"torch": ["3.10", "3.11", "3.12", "3.13", "3.14"],
"timm": ["3.10", "3.11", "3.12", "3.13", "3.14"],
"tensorflow": ["3.10", "3.11", "3.12"],
"gguf": ["3.10", "3.11", "3.12", "3.13"],
"openai": ["3.10", "3.11", "3.12", "3.13"],
"transformers": ["3.10", "3.11", "3.12", "3.13"],
"ultralytics": ["3.10", "3.11", "3.12", "3.13"],
"mlx": ["3.12", "3.13"],
}

# Interpreter-specific overrides where upstream wheels are unavailable for older pins.
Expand All @@ -58,6 +83,11 @@
"3.14": ["torch==2.10.0", "numpy==2.1.3"],
},
},
"transformers": {
"min": {
"3.13": ["transformers==4.45.0", "torch==2.5.0", "numpy==2.1.3"],
}
},
"timm": {
"min": {
"3.13": [
Expand Down Expand Up @@ -113,6 +143,17 @@ def print_deps(integration: str, version_set: str, python_version: str) -> int:
return 0


def print_rows() -> int:
rows = [
{"integration": integration, "version_set": version_set, "python_version": py}
for integration, sets in MATRIX.items()
for version_set in sets
for py in SUPPORTED_PYTHON[integration]
]
print(json.dumps(rows))
return 0


def print_table() -> int:
print("| Integration | Version set | Dependencies | Supported Python |")
print("|---|---|---|---|")
Expand Down Expand Up @@ -140,11 +181,14 @@ def main() -> int:
deps.add_argument("--version-set", required=True, choices=["min", "current"])
deps.add_argument("--python-version", required=True)

sub.add_parser("rows")
sub.add_parser("table")

args = parser.parse_args()
if args.cmd == "deps":
return print_deps(args.integration, args.version_set, args.python_version)
if args.cmd == "rows":
return print_rows()
if args.cmd == "table":
return print_table()
return 2
Expand Down
86 changes: 55 additions & 31 deletions scripts/run_compat_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
from __future__ import annotations

import argparse
import json
import os
import subprocess
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path

Expand All @@ -19,18 +23,18 @@ class Row:
version_set: str


WORKFLOW_ROWS: list[Row] = [
# compat job
*[
Row(py, integration, version_set)
for py in ("3.10", "3.11", "3.12", "3.13")
for integration in ("onnx", "torch", "timm", "tensorflow")
for version_set in ("min", "current")
if not (py == "3.13" and integration == "tensorflow")
],
# compat-canary-314 job
*[Row("3.14", integration, "current") for integration in ("torch", "timm")],
]
def load_rows() -> list[Row]:
result = subprocess.run(
[sys.executable, str(REPO_ROOT / "scripts" / "compat_matrix.py"), "rows"],
check=True,
capture_output=True,
text=True,
cwd=REPO_ROOT,
)
return [Row(**r) for r in json.loads(result.stdout)]


WORKFLOW_ROWS: list[Row] = load_rows()


UNSUPPORTED_MARKERS = (
Expand Down Expand Up @@ -67,10 +71,15 @@ def run_row(row: Row) -> tuple[str, str]:
"run",
"--python",
row.python_version,
"--link-mode=copy",
"--with-editable",
".",
"--with",
"pytest",
"--with",
"pytest-asyncio",
"--with",
"pytest-mock",
]
for dep in deps:
cmd.extend(["--with", dep])
Expand All @@ -84,9 +93,11 @@ def run_row(row: Row) -> tuple[str, str]:
]
)

result = subprocess.run(
cmd, check=False, capture_output=True, text=True, cwd=REPO_ROOT
)
with tempfile.TemporaryDirectory() as tmpdir:
env = {**os.environ, "UV_PROJECT_ENVIRONMENT": tmpdir}
result = subprocess.run(
cmd, check=False, capture_output=True, text=True, cwd=REPO_ROOT, env=env
)
output = f"{result.stdout}\n{result.stderr}".strip()

if result.returncode == 0:
Expand All @@ -103,29 +114,42 @@ def main() -> int:
action="store_true",
help="Treat unsupported dependency rows as failures.",
)
parser.add_argument(
"--jobs",
"-j",
type=int,
default=os.cpu_count() or 4,
help="Number of rows to run in parallel (default: cpu count).",
)
args = parser.parse_args()

passed = 0
failed = 0
skipped = 0

for row in WORKFLOW_ROWS:
label = f"{row.python_version} | {row.integration} | {row.version_set}"
print(f"==> {label}")
status, output = run_row(row)
print(status)
if output:
print(output)
print()

if status == "PASS":
passed += 1
elif status == "SKIP_UNSUPPORTED":
skipped += 1
if args.strict_unsupported:
futures = {}
with ThreadPoolExecutor(max_workers=args.jobs) as executor:
for row in WORKFLOW_ROWS:
futures[executor.submit(run_row, row)] = row

for future in as_completed(futures):
row = futures[future]
label = f"{row.python_version} | {row.integration} | {row.version_set}"
status, output = future.result()
print(f"==> {label}")
print(status)
if output:
print(output)
print()

if status == "PASS":
passed += 1
elif status == "SKIP_UNSUPPORTED":
skipped += 1
if args.strict_unsupported:
failed += 1
else:
failed += 1
else:
failed += 1

print(
"SUMMARY "
Expand Down
9 changes: 9 additions & 0 deletions tests/compat/test_gguf_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

import pytest


def test_gguf_import_and_instrument(compat_client):
llama_cpp = pytest.importorskip("llama_cpp")
assert hasattr(llama_cpp, "Llama")
compat_client.instrument("gguf")
17 changes: 17 additions & 0 deletions tests/compat/test_mlx_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from __future__ import annotations

import pytest


def test_mlx_import_and_instrumentation(compat_client):
pytest.importorskip("mlx_lm")
mx = pytest.importorskip("mlx.core")
nn = pytest.importorskip("mlx.nn")

compat_client.instrument("mlx")

model = nn.Linear(4, 2)
x = mx.ones((3, 4))
y = model(x)
mx.eval(y)
assert y.shape == (3, 2)
9 changes: 9 additions & 0 deletions tests/compat/test_openai_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

import pytest


def test_openai_import_and_instrument(compat_client):
openai = pytest.importorskip("openai")
assert getattr(openai, "__version__", None)
compat_client.instrument("openai")
25 changes: 25 additions & 0 deletions tests/compat/test_transformers_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from __future__ import annotations

import pytest


def test_transformers_import_and_instrumentation(compat_client):
torch = pytest.importorskip("torch")
transformers = pytest.importorskip("transformers")

compat_client.instrument("transformers")

config = transformers.BertConfig(
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=2,
intermediate_size=64,
num_labels=2,
)
model = transformers.BertForSequenceClassification(config)
inputs = {
"input_ids": torch.zeros((1, 4), dtype=torch.long),
"attention_mask": torch.ones((1, 4), dtype=torch.long),
}
out = model(**inputs)
assert tuple(out.logits.shape) == (1, 2)
9 changes: 9 additions & 0 deletions tests/compat/test_ultralytics_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

import pytest


def test_ultralytics_import_and_instrument(compat_client):
ultralytics = pytest.importorskip("ultralytics")
assert getattr(ultralytics, "__version__", None)
compat_client.instrument("ultralytics")
Loading