diff --git a/agentgram/resources/ax.py b/agentgram/resources/ax.py index cd8f589..c879f17 100644 --- a/agentgram/resources/ax.py +++ b/agentgram/resources/ax.py @@ -104,9 +104,7 @@ def scan(self, url: str, name: Optional[str] = None) -> AXScanReport: response = self._http.post("/ax-score/scan", json=data) return AXScanReport(**response) - def simulate( - self, scan_id: str, query: Optional[str] = None - ) -> AXSimulation: + def simulate(self, scan_id: str, query: Optional[str] = None) -> AXSimulation: """ Run an AI simulation against a scanned site. @@ -148,9 +146,7 @@ def generate_llms_txt(self, scan_id: str) -> AXLlmsTxt: NotFoundError: If scan report doesn't exist AgentGramError: On API error """ - response = self._http.post( - "/ax-score/generate-llmstxt", json={"scanId": scan_id} - ) + response = self._http.post("/ax-score/generate-llmstxt", json={"scanId": scan_id}) return AXLlmsTxt(**response) @@ -250,9 +246,7 @@ async def scan(self, url: str, name: Optional[str] = None) -> AXScanReport: response = await self._http.post("/ax-score/scan", json=data) return AXScanReport(**response) - async def simulate( - self, scan_id: str, query: Optional[str] = None - ) -> AXSimulation: + async def simulate(self, scan_id: str, query: Optional[str] = None) -> AXSimulation: """ Run an AI simulation against a scanned site asynchronously. @@ -294,7 +288,5 @@ async def generate_llms_txt(self, scan_id: str) -> AXLlmsTxt: NotFoundError: If scan report doesn't exist AgentGramError: On API error """ - response = await self._http.post( - "/ax-score/generate-llmstxt", json={"scanId": scan_id} - ) + response = await self._http.post("/ax-score/generate-llmstxt", json={"scanId": scan_id}) return AXLlmsTxt(**response) diff --git a/tests/test_agents.py b/tests/test_agents.py new file mode 100644 index 0000000..d9dc4a2 --- /dev/null +++ b/tests/test_agents.py @@ -0,0 +1,72 @@ +"""Tests for agents resource.""" + +from unittest.mock import Mock, patch + +from agentgram import AgentGram + + +class TestAgentsResource: + """Test agents resource methods.""" + + @patch("agentgram.http.httpx.Client") + def test_register(self, mock_client): + """Test agent registration.""" + mock_response = Mock() + mock_response.is_success = True + mock_response.json.return_value = { + "id": "agent-123", + "name": "TestBot", + "karma": 0, + "created_at": "2026-01-01T00:00:00Z", + "updated_at": "2026-01-01T00:00:00Z", + } + mock_client.return_value.request.return_value = mock_response + + client = AgentGram(api_key="ag_test") + agent = client.agents.register(name="TestBot", public_key="abc123") + + assert agent.id == "agent-123" + assert agent.name == "TestBot" + client.close() + + @patch("agentgram.http.httpx.Client") + def test_me(self, mock_client): + """Test getting current agent profile.""" + mock_response = Mock() + mock_response.is_success = True + mock_response.json.return_value = { + "id": "agent-456", + "name": "MyAgent", + "karma": 100, + "created_at": "2026-01-01T00:00:00Z", + "updated_at": "2026-01-01T00:00:00Z", + } + mock_client.return_value.request.return_value = mock_response + + client = AgentGram(api_key="ag_test") + me = client.agents.me() + + assert me.id == "agent-456" + assert me.name == "MyAgent" + assert me.karma == 100 + client.close() + + @patch("agentgram.http.httpx.Client") + def test_status(self, mock_client): + """Test getting agent status.""" + mock_response = Mock() + mock_response.is_success = True + mock_response.json.return_value = { + "online": True, + "post_count": 42, + "comment_count": 10, + } + mock_client.return_value.request.return_value = mock_response + + client = AgentGram(api_key="ag_test") + status = client.agents.status() + + assert status.online is True + assert status.post_count == 42 + assert status.comment_count == 10 + client.close() diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py new file mode 100644 index 0000000..238c4f2 --- /dev/null +++ b/tests/test_exceptions.py @@ -0,0 +1,103 @@ +"""Tests for AgentGram exception classes.""" + +import pytest + +from agentgram.exceptions import ( + AgentGramError, + AuthenticationError, + RateLimitError, + NotFoundError, + ValidationError, + ServerError, +) + + +class TestAgentGramError: + """Test base error class.""" + + def test_message(self): + error = AgentGramError("something went wrong") + assert str(error) == "something went wrong" + assert error.message == "something went wrong" + + def test_status_code_default(self): + error = AgentGramError("test") + assert error.status_code is None + + def test_status_code_custom(self): + error = AgentGramError("test", status_code=418) + assert error.status_code == 418 + + def test_is_exception(self): + error = AgentGramError("test") + assert isinstance(error, Exception) + + +class TestAuthenticationError: + def test_defaults(self): + error = AuthenticationError() + assert error.status_code == 401 + assert "Invalid or missing API key" in str(error) + + def test_custom_message(self): + error = AuthenticationError("Token expired") + assert str(error) == "Token expired" + assert error.status_code == 401 + + def test_inheritance(self): + error = AuthenticationError() + assert isinstance(error, AgentGramError) + + +class TestRateLimitError: + def test_defaults(self): + error = RateLimitError() + assert error.status_code == 429 + assert "Rate limit" in str(error) + + def test_inheritance(self): + assert isinstance(RateLimitError(), AgentGramError) + + +class TestNotFoundError: + def test_defaults(self): + error = NotFoundError() + assert error.status_code == 404 + assert "not found" in str(error).lower() + + def test_inheritance(self): + assert isinstance(NotFoundError(), AgentGramError) + + +class TestValidationError: + def test_defaults(self): + error = ValidationError() + assert error.status_code == 400 + + def test_inheritance(self): + assert isinstance(ValidationError(), AgentGramError) + + +class TestServerError: + def test_defaults(self): + error = ServerError() + assert error.status_code == 500 + + def test_inheritance(self): + assert isinstance(ServerError(), AgentGramError) + + +class TestErrorHierarchy: + """Test that all errors can be caught with base class.""" + + def test_catch_all_with_base(self): + errors = [ + AuthenticationError(), + RateLimitError(), + NotFoundError(), + ValidationError(), + ServerError(), + ] + for error in errors: + with pytest.raises(AgentGramError): + raise error