diff --git a/backend/chainlit/element.py b/backend/chainlit/element.py index 23a8b41a4a..785d609b5f 100644 --- a/backend/chainlit/element.py +++ b/backend/chainlit/element.py @@ -64,6 +64,7 @@ class ElementDict(TypedDict, total=False): playerConfig: Optional[dict] forId: Optional[str] mime: Optional[str] + downloadName: Optional[str] @dataclass @@ -96,6 +97,8 @@ class Element: language: Optional[str] = None # Mime type, inferred based on content if not provided mime: Optional[str] = None + # Custom download filename. If set, this name is used when the file is downloaded. + download_name: Optional[str] = None def __post_init__(self) -> None: self.persisted = False @@ -123,6 +126,7 @@ def to_dict(self) -> ElementDict: "language": getattr(self, "language", None), "forId": getattr(self, "for_id", None), "mime": getattr(self, "mime", None), + "downloadName": getattr(self, "download_name", None), } ) return _dict @@ -149,6 +153,7 @@ def from_dict(cls, e_dict: ElementDict): chainlit_key = e_dict.get("chainlitKey") display = e_dict.get("display", "inline") mime_type = e_dict.get("mime", "") + download_name = e_dict.get("downloadName") # Common parameters for all element types common_params = { @@ -162,6 +167,7 @@ def from_dict(cls, e_dict: ElementDict): "chainlit_key": chainlit_key, "display": display, "mime": mime_type, + "download_name": download_name, } if type == "image": @@ -218,6 +224,7 @@ async def _create(self, persist=True) -> bool: path=self.path, content=self.content, mime=self.mime or "", + download_name=self.download_name, ) self.chainlit_key = file_dict["id"] diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index 6592af08fd..d414e4ae6b 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -1658,7 +1658,13 @@ async def get_file( if file_id in session.files: file = session.files[file_id] - return FileResponse(file["path"], media_type=file["type"]) + filename = file.get("download_name") or file.get("name") + return FileResponse( + file["path"], + media_type=file["type"], + filename=filename, + content_disposition_type="inline", + ) else: raise HTTPException(status_code=404, detail="File not found") diff --git a/backend/chainlit/session.py b/backend/chainlit/session.py index d6bd3f6214..e30c0c6fe3 100644 --- a/backend/chainlit/session.py +++ b/backend/chainlit/session.py @@ -100,6 +100,7 @@ async def persist_file( mime: str, path: Optional[str] = None, content: Optional[Union[bytes, str]] = None, + download_name: Optional[str] = None, ) -> FileReference: if not path and not content: raise ValueError( @@ -140,6 +141,7 @@ async def persist_file( "name": name, "type": mime, "size": file_size, + "download_name": download_name, } return {"id": file_id} diff --git a/backend/chainlit/types.py b/backend/chainlit/types.py index 2adcc8ad4e..82a39a84d8 100644 --- a/backend/chainlit/types.py +++ b/backend/chainlit/types.py @@ -167,6 +167,7 @@ class FileDict(TypedDict): path: Path size: int type: str + download_name: Optional[str] class MessagePayload(TypedDict):