diff --git a/nubison_model/Service.py b/nubison_model/Service.py index f435d12..dd8f509 100644 --- a/nubison_model/Service.py +++ b/nubison_model/Service.py @@ -1,16 +1,21 @@ +from contextlib import contextmanager from functools import wraps from os import environ, getenv +from tempfile import TemporaryDirectory from typing import Optional +from unittest.mock import patch import bentoml from mlflow import set_tracking_uri from mlflow.pyfunc import load_model +from starlette.testclient import TestClient from nubison_model.Model import ( DEAFULT_MLFLOW_URI, ENV_VAR_MLFLOW_MODEL_URI, ENV_VAR_MLFLOW_TRACKING_URI, ) +from nubison_model.utils import temporary_cwd def load_nubison_model( @@ -37,6 +42,20 @@ def load_nubison_model( return nubison_model +@contextmanager +def test_client(model_uri): + app = build_inference_service(mlflow_model_uri=model_uri) + # Disable metrics for testing. Avoids Prometheus client duplicated registration error + app.config["metrics"] = {"enabled": False} + + # Create a temporary directory and set it as the current working directory to run tests + # To avoid model initialization conflicts with the current directory + test_dir = TemporaryDirectory() + with temporary_cwd(test_dir.name), TestClient(app.to_asgi()) as client: + + yield client + + def build_inference_service( mlflow_tracking_uri: Optional[str] = None, mlflow_model_uri: Optional[str] = None ): diff --git a/nubison_model/__init__.py b/nubison_model/__init__.py index 2ca0ddd..5562a45 100644 --- a/nubison_model/__init__.py +++ b/nubison_model/__init__.py @@ -8,7 +8,7 @@ NubisonModel, register, ) -from .Service import build_inference_service +from .Service import build_inference_service, test_client __all__ = [ "ENV_VAR_MLFLOW_MODEL_URI", @@ -16,4 +16,5 @@ "NubisonModel", "register", "build_inference_service", + "test_client", ] diff --git a/nubison_model/utils.py b/nubison_model/utils.py new file mode 100644 index 0000000..76889ad --- /dev/null +++ b/nubison_model/utils.py @@ -0,0 +1,12 @@ +from contextlib import contextmanager +from os import chdir, getcwd + + +@contextmanager +def temporary_cwd(new_dir): + original_dir = getcwd() + try: + chdir(new_dir) + yield + finally: + chdir(original_dir) diff --git a/test/test_service.py b/test/test_service.py index 3be9f69..32fd11e 100644 --- a/test/test_service.py +++ b/test/test_service.py @@ -2,7 +2,7 @@ import pytest -from nubison_model import build_inference_service +from nubison_model import build_inference_service, test_client def test_raise_runtime_error_on_missing_env(): @@ -19,3 +19,20 @@ def infer(self, test: str): mock_load_nubison_model.return_value = DummyModel() service = build_inference_service()() assert service.infer("test") == "test" + + +def test_client_ok(): + class DummyModel: + def infer(self, test: str): + return test + + with patch("nubison_model.Service.load_nubison_model") as mock_load_nubison_model: + mock_load_nubison_model.return_value = DummyModel() + with test_client("test") as client: + response = client.post("/infer", json={"test": "test"}) + assert response.status_code == 200 + assert response.text == "test" + + +# Ignore the test_client from being collected by pytest +setattr(test_client, "__test__", False) diff --git a/test/utils.py b/test/utils.py index 7e9815d..4119f2d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1,7 +1,9 @@ from contextlib import contextmanager -from os import chdir, environ, getcwd, makedirs, path +from os import environ, getcwd, makedirs, path from shutil import rmtree -from typing import List, Optional +from typing import List + +from nubison_model.utils import temporary_cwd @contextmanager @@ -20,16 +22,6 @@ def temporary_dirs(dirs: List[str]): rmtree(dir) -@contextmanager -def temporary_cwd(new_dir): - original_dir = getcwd() - try: - chdir(new_dir) - yield - finally: - chdir(original_dir) - - @contextmanager def temporary_env(env: dict): original_env = environ.copy() @@ -40,3 +32,6 @@ def temporary_env(env: dict): environ.clear() environ.update(original_env) + + +__all__ = ["temporary_cwd", "temporary_dirs", "temporary_env"]