Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions backend/chainlit/data/chainlit_data_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]:
return PersistedUser(
id=str(row.get("id")),
identifier=str(row.get("identifier")),
createdAt=row.get("createdAt").isoformat(), # type: ignore
createdAt=row.get("createdAt").isoformat() + "Z", # type: ignore
metadata=json.loads(row.get("metadata", "{}")),
)

Expand All @@ -121,7 +121,7 @@ async def create_user(self, user: User) -> Optional[PersistedUser]:
return PersistedUser(
id=str(row.get("id")),
identifier=str(row.get("identifier")),
createdAt=row.get("createdAt").isoformat(), # type: ignore
createdAt=row.get("createdAt").isoformat() + "Z", # type: ignore
metadata=json.loads(row.get("metadata", "{}")),
)

Expand Down Expand Up @@ -497,7 +497,7 @@ async def list_threads(
for thread in threads:
thread_dict = ThreadDict(
id=str(thread["id"]),
createdAt=thread["updatedAt"].isoformat(),
createdAt=thread["updatedAt"].isoformat() + "Z",
name=thread["name"],
userId=str(thread["userId"]) if thread["userId"] else None,
userIdentifier=thread["user_identifier"],
Expand Down Expand Up @@ -561,7 +561,7 @@ async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:

return ThreadDict(
id=str(thread["id"]),
createdAt=thread["createdAt"].isoformat(),
createdAt=thread["createdAt"].isoformat() + "Z",
name=thread["name"],
userId=str(thread["userId"]) if thread["userId"] else None,
userIdentifier=thread["user_identifier"],
Expand Down Expand Up @@ -681,11 +681,13 @@ def _convert_step_row_to_dict(self, row: Dict) -> StepDict:
input=row.get("input", {}),
output=row.get("output", {}),
metadata=json.loads(row.get("metadata", "{}")),
createdAt=row["createdAt"].isoformat() if row.get("createdAt") else None,
start=row["startTime"].isoformat() if row.get("startTime") else None,
createdAt=row["createdAt"].isoformat() + "Z"
if row.get("createdAt")
else None,
start=row["startTime"].isoformat() + "Z" if row.get("startTime") else None,
showInput=row.get("showInput"),
isError=row.get("isError"),
end=row["endTime"].isoformat() if row.get("endTime") else None,
end=row["endTime"].isoformat() + "Z" if row.get("endTime") else None,
feedback=self._extract_feedback_dict_from_step_row(row),
)

Expand Down
205 changes: 204 additions & 1 deletion backend/tests/data/test_chainlit_data_layer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import json
from datetime import datetime
from unittest.mock import AsyncMock

import pytest

from chainlit import User
from chainlit.data.chainlit_data_layer import ChainlitDataLayer
from chainlit.types import Pagination, ThreadFilter


@pytest.mark.asyncio
Expand Down Expand Up @@ -142,7 +145,207 @@ async def mock_execute_query(query, params):


@pytest.mark.asyncio
async def test_create_step_uses_nullif_for_output_and_input():
async def test_get_user_returns_iso_format_with_z_suffix():
"""Test that get_user returns createdAt with 'Z' suffix for chainlit/utils.py utc_now() compliance."""
data_layer = ChainlitDataLayer(
database_url="postgresql://test", storage_client=None, show_logger=False
)

mock_created_at = datetime(2024, 1, 15, 10, 30, 45, 123456)

async def mock_execute_query(query, params):
if "SELECT" in query and "User" in query:
return [
{
"id": "user-123",
"identifier": "test@example.com",
"createdAt": mock_created_at,
"metadata": "{}",
}
]
return []

data_layer.execute_query = AsyncMock(side_effect=mock_execute_query)

result = await data_layer.get_user("test@example.com")

assert result is not None
assert result.id == "user-123"
assert result.identifier == "test@example.com"
assert result.createdAt == "2024-01-15T10:30:45.123456Z"


@pytest.mark.asyncio
async def test_create_user_returns_iso_format_with_z_suffix():
"""Test that create_user returns createdAt with 'Z' suffix for chainlit/utils.py utc_now() compliance."""
data_layer = ChainlitDataLayer(
database_url="postgresql://test", storage_client=None, show_logger=False
)

mock_created_at = datetime(2024, 1, 15, 10, 30, 45, 123456)

async def mock_execute_query(query, params):
if "INSERT" in query and "User" in query:
return [
{
"id": "user-456",
"identifier": "newuser@example.com",
"createdAt": mock_created_at,
"metadata": '{"role": "admin"}',
}
]
return []

data_layer.execute_query = AsyncMock(side_effect=mock_execute_query)
data_layer.get_current_timestamp = AsyncMock(return_value=mock_created_at)

user = User(identifier="newuser@example.com", metadata={"role": "admin"})

result = await data_layer.create_user(user)

assert result is not None
assert result.id == "user-456"
assert result.identifier == "newuser@example.com"
assert result.createdAt == "2024-01-15T10:30:45.123456Z"


@pytest.mark.asyncio
async def test_list_threads_returns_iso_format_with_z_suffix():
"""Test that list_threads returns createdAt with 'Z' suffix for chainlit/utils.py utc_now() compliance."""
data_layer = ChainlitDataLayer(
database_url="postgresql://test", storage_client=None, show_logger=False
)
mock_updated_at = datetime(2024, 2, 20, 14, 15, 30, 987654)

async def mock_execute_query(query, params):
if "SELECT" in query and "Thread" in query:
return [
{
"id": "thread-789",
"name": "Test Thread",
"userId": "user-123",
"user_identifier": "test@example.com",
"updatedAt": mock_updated_at,
"metadata": "{}",
"total": 1,
}
]
return []

data_layer.execute_query = AsyncMock(side_effect=mock_execute_query)

pagination = Pagination(first=10, cursor=None)
filters = ThreadFilter(userId=None, search=None, feedback=None)

result = await data_layer.list_threads(pagination, filters)

assert result is not None
assert len(result.data) == 1
thread = result.data[0]
assert thread["id"] == "thread-789"
assert thread["createdAt"] == "2024-02-20T14:15:30.987654Z"


@pytest.mark.asyncio
async def test_get_thread_returns_iso_format_with_z_suffix():
"""Test that get_thread returns createdAt with 'Z' suffix for chainlit/utils.py utc_now() compliance."""
data_layer = ChainlitDataLayer(
database_url="postgresql://test", storage_client=None, show_logger=False
)
mock_created_at = datetime(2024, 3, 10, 9, 20, 15, 456789)

async def mock_execute_query(query, params):
if "SELECT t.*" in query and "Thread" in query:
return [
{
"id": "thread-101",
"name": "Single Thread",
"userId": "user-456",
"user_identifier": "user@example.com",
"createdAt": mock_created_at,
"metadata": "{}",
}
]
return []

data_layer.execute_query = AsyncMock(side_effect=mock_execute_query)

result = await data_layer.get_thread("thread-101")

assert result is not None
assert result["id"] == "thread-101"
assert result["createdAt"] == "2024-03-10T09:20:15.456789Z"


def test_convert_step_row_to_dict_returns_iso_format_with_z_suffix():
"""Test that _convert_step_row_to_dict returns timestamps with 'Z' suffix for chainlit/utils.py utc_now() compliance."""
data_layer = ChainlitDataLayer(
database_url="postgresql://test", storage_client=None, show_logger=False
)

mock_created_at = datetime(2024, 4, 5, 12, 0, 0, 111111)
mock_start_time = datetime(2024, 4, 5, 12, 0, 5, 222222)
mock_end_time = datetime(2024, 4, 5, 12, 0, 10, 333333)

mock_row = {
"id": "step-202",
"threadId": "thread-303",
"parentId": None,
"name": "Test Step",
"type": "user_message",
"input": {"content": "Hello"},
"output": {"response": "Hi there"},
"metadata": "{}",
"createdAt": mock_created_at,
"startTime": mock_start_time,
"endTime": mock_end_time,
"showInput": "json",
"isError": False,
"feedback_id": None,
}

result = data_layer._convert_step_row_to_dict(mock_row)

assert result is not None
assert result["id"] == "step-202"
assert result["createdAt"] == "2024-04-05T12:00:00.111111Z"
assert result["start"] == "2024-04-05T12:00:05.222222Z"
assert result["end"] == "2024-04-05T12:00:10.333333Z"


def test_convert_step_row_to_dict_handles_none_timestamps():
"""Test that _convert_step_row_to_dict handles None timestamps correctly."""
data_layer = ChainlitDataLayer(
database_url="postgresql://test", storage_client=None, show_logger=False
)

mock_row = {
"id": "step-303",
"threadId": "thread-404",
"parentId": None,
"name": "Test Step",
"type": "user_message",
"input": {},
"output": {},
"metadata": "{}",
"createdAt": None,
"startTime": None,
"endTime": None,
"showInput": "json",
"isError": False,
"feedback_id": None,
}

result = data_layer._convert_step_row_to_dict(mock_row)

assert result is not None
assert result["id"] == "step-303"
assert result["createdAt"] is None
assert result["start"] is None
assert result["end"] is None


def test_create_step_uses_nullif_for_output_and_input():
"""Empty-string output/input should not overwrite existing content.

Regression test for https://github.com/Chainlit/chainlit/issues/2789
Expand Down
Loading