Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
name: Run tests and upload coverage

on:
push
push:
branches:
- main
pull_request:

jobs:
test:
Expand Down
14 changes: 8 additions & 6 deletions tests/test_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import numpy as np
import pytest
from pytest import LogCaptureFixture
from transformers import BertTokenizerFast
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast

from model2vec.distill.distillation import distill, distill_from_model
from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings
Expand Down Expand Up @@ -42,7 +42,7 @@
def test_distill_from_model(
mock_auto_model: MagicMock,
mock_model_info: MagicMock,
mock_berttokenizer: BertTokenizerFast,
mock_berttokenizer: PreTrainedTokenizerFast,
mock_transformer: PreTrainedModel,
vocabulary: list[str] | None,
pca_dims: int | None,
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_distill_from_model(
def test_distill_removal_pattern(
mock_auto_model: MagicMock,
mock_model_info: MagicMock,
mock_berttokenizer: BertTokenizerFast,
mock_berttokenizer: PreTrainedTokenizerFast,
mock_transformer: PreTrainedModel,
) -> None:
"""Test the removal pattern."""
Expand Down Expand Up @@ -180,10 +180,12 @@ def test_distill(
def test_missing_modelinfo(
mock_model_info: MagicMock,
mock_transformer: PreTrainedModel,
mock_berttokenizer: BertTokenizerFast,
mock_berttokenizer: PreTrainedTokenizerFast,
) -> None:
"""Test that missing model info does not crash."""
mock_model_info.side_effect = RepositoryNotFoundError("Model not found")
mock_response = MagicMock()
mock_response.status_code = 404
mock_model_info.side_effect = RepositoryNotFoundError("Model not found", response=mock_response)
static_model = distill_from_model(model=mock_transformer, tokenizer=mock_berttokenizer, device="cpu")
assert static_model.language is None

Expand Down Expand Up @@ -237,7 +239,7 @@ def test__post_process_embeddings(
],
)
def test_clean_and_create_vocabulary(
mock_berttokenizer: BertTokenizerFast,
mock_berttokenizer: PreTrainedTokenizerFast,
added_tokens: list[str],
expected_output: list[str],
expected_warnings: list[str],
Expand Down