diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 9c0b020..b204926 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -1,7 +1,10 @@
 name: Run tests and upload coverage
 
 on:
-  push
+  push:
+    branches:
+      - main
+  pull_request:
 
 jobs:
   test:
diff --git a/tests/test_distillation.py b/tests/test_distillation.py
index d2e2c5d..3c50ae6 100644
--- a/tests/test_distillation.py
+++ b/tests/test_distillation.py
@@ -8,8 +8,8 @@
 import numpy as np
 import pytest
 from pytest import LogCaptureFixture
-from transformers import BertTokenizerFast
 from transformers.modeling_utils import PreTrainedModel
+from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
 
 from model2vec.distill.distillation import distill, distill_from_model
 from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings
@@ -42,7 +42,7 @@
 def test_distill_from_model(
     mock_auto_model: MagicMock,
     mock_model_info: MagicMock,
-    mock_berttokenizer: BertTokenizerFast,
+    mock_berttokenizer: PreTrainedTokenizerFast,
     mock_transformer: PreTrainedModel,
     vocabulary: list[str] | None,
     pca_dims: int | None,
@@ -83,7 +83,7 @@ def test_distill_from_model(
 def test_distill_removal_pattern(
     mock_auto_model: MagicMock,
     mock_model_info: MagicMock,
-    mock_berttokenizer: BertTokenizerFast,
+    mock_berttokenizer: PreTrainedTokenizerFast,
     mock_transformer: PreTrainedModel,
 ) -> None:
     """Test the removal pattern."""
@@ -180,10 +180,12 @@ def test_distill(
 def test_missing_modelinfo(
     mock_model_info: MagicMock,
     mock_transformer: PreTrainedModel,
-    mock_berttokenizer: BertTokenizerFast,
+    mock_berttokenizer: PreTrainedTokenizerFast,
 ) -> None:
     """Test that missing model info does not crash."""
-    mock_model_info.side_effect = RepositoryNotFoundError("Model not found")
+    mock_response = MagicMock()
+    mock_response.status_code = 404
+    mock_model_info.side_effect = RepositoryNotFoundError("Model not found", response=mock_response)
     static_model = distill_from_model(model=mock_transformer, tokenizer=mock_berttokenizer, device="cpu")
 
     assert static_model.language is None
@@ -237,7 +239,7 @@ def test__post_process_embeddings(
     ],
 )
 def test_clean_and_create_vocabulary(
-    mock_berttokenizer: BertTokenizerFast,
+    mock_berttokenizer: PreTrainedTokenizerFast,
     added_tokens: list[str],
     expected_output: list[str],
     expected_warnings: list[str],