diff --git a/langserve/playground.py b/langserve/playground.py index 66ca639c..c5f915be 100644 --- a/langserve/playground.py +++ b/langserve/playground.py @@ -14,6 +14,39 @@ class PlaygroundTemplate(Template): delimiter = "____" +def _get_mimetype(path: str) -> str: + """Get mimetype for file. + + Custom implementation of mimetypes.guess_type that + uses the file extension to determine the mimetype for some files. + + This is necessary due to: https://bugs.python.org/issue43975 + Resolves issue: https://github.com/langchain-ai/langserve/issues/245 + + Args: + path (str): Path to file + + Returns: + str: Mimetype of file + """ + try: + file_extension = path.lower().split(".")[-1] + except IndexError: + return mimetypes.guess_type(path)[0] + + if file_extension == "js": + return "application/javascript" + elif file_extension == "css": + return "text/css" + elif file_extension in ["htm", "html"]: + return "text/html" + + # If the file extension is not one of the specified ones, + # use the default guess method + mime_type = mimetypes.guess_type(path)[0] + return mime_type + + async def serve_playground( runnable: Runnable, input_schema: Type[BaseModel], @@ -39,7 +72,7 @@ async def serve_playground( try: with open(local_file_path, encoding="utf-8") as f: - mime_type = mimetypes.guess_type(local_file_path)[0] + mime_type = _get_mimetype(local_file_path) if mime_type in ("text/html", "text/css", "application/javascript"): response = PlaygroundTemplate(f.read()).substitute( LANGSERVE_BASE_URL=base_url[1:] diff --git a/tests/unit_tests/test_playground.py b/tests/unit_tests/test_playground.py new file mode 100644 index 00000000..fa49d8d6 --- /dev/null +++ b/tests/unit_tests/test_playground.py @@ -0,0 +1,24 @@ +import pytest + +from langserve.playground import _get_mimetype + + +@pytest.mark.parametrize( + "file_extension, expected_mimetype", + [ + ("js", "application/javascript"), + ("css", "text/css"), + ("htm", "text/html"), + ("html", "text/html"), + ("txt", "text/plain"), # An example of an unknown extension using guess_type + ], +) +def test_get_mimetype(file_extension: str, expected_mimetype: str) -> None: + # Create a filename with the given extension + filename = f"test_file.{file_extension}" + + # Call the _get_mimetype function with the test filename + mimetype = _get_mimetype(filename) + + # Check if the returned mimetype matches the expected one + assert mimetype == expected_mimetype