Skip to content

Commit

Permalink
feat: add test client for easy test against server (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyuwoo-choi authored Nov 13, 2024
1 parent e6075f4 commit 587667f
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 14 deletions.
19 changes: 19 additions & 0 deletions nubison_model/Service.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from contextlib import contextmanager
from functools import wraps
from os import environ, getenv
from tempfile import TemporaryDirectory
from typing import Optional
from unittest.mock import patch

import bentoml
from mlflow import set_tracking_uri
from mlflow.pyfunc import load_model
from starlette.testclient import TestClient

from nubison_model.Model import (
DEAFULT_MLFLOW_URI,
ENV_VAR_MLFLOW_MODEL_URI,
ENV_VAR_MLFLOW_TRACKING_URI,
)
from nubison_model.utils import temporary_cwd


def load_nubison_model(
Expand All @@ -37,6 +42,20 @@ def load_nubison_model(
return nubison_model


@contextmanager
def test_client(model_uri):
app = build_inference_service(mlflow_model_uri=model_uri)
# Disable metrics for testing. Avoids Prometheus client duplicated registration error
app.config["metrics"] = {"enabled": False}

# Create a temporary directory and set it as the current working directory to run tests
# To avoid model initialization conflicts with the current directory
test_dir = TemporaryDirectory()
with temporary_cwd(test_dir.name), TestClient(app.to_asgi()) as client:

yield client


def build_inference_service(
mlflow_tracking_uri: Optional[str] = None, mlflow_model_uri: Optional[str] = None
):
Expand Down
3 changes: 2 additions & 1 deletion nubison_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
NubisonModel,
register,
)
from .Service import build_inference_service
from .Service import build_inference_service, test_client

__all__ = [
"ENV_VAR_MLFLOW_MODEL_URI",
"ENV_VAR_MLFLOW_TRACKING_URI",
"NubisonModel",
"register",
"build_inference_service",
"test_client",
]
12 changes: 12 additions & 0 deletions nubison_model/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from contextlib import contextmanager
from os import chdir, getcwd


@contextmanager
def temporary_cwd(new_dir):
original_dir = getcwd()
try:
chdir(new_dir)
yield
finally:
chdir(original_dir)
19 changes: 18 additions & 1 deletion test/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from nubison_model import build_inference_service
from nubison_model import build_inference_service, test_client


def test_raise_runtime_error_on_missing_env():
Expand All @@ -19,3 +19,20 @@ def infer(self, test: str):
mock_load_nubison_model.return_value = DummyModel()
service = build_inference_service()()
assert service.infer("test") == "test"


def test_client_ok():
class DummyModel:
def infer(self, test: str):
return test

with patch("nubison_model.Service.load_nubison_model") as mock_load_nubison_model:
mock_load_nubison_model.return_value = DummyModel()
with test_client("test") as client:
response = client.post("/infer", json={"test": "test"})
assert response.status_code == 200
assert response.text == "test"


# Ignore the test_client from being collected by pytest
setattr(test_client, "__test__", False)
19 changes: 7 additions & 12 deletions test/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from contextlib import contextmanager
from os import chdir, environ, getcwd, makedirs, path
from os import environ, getcwd, makedirs, path
from shutil import rmtree
from typing import List, Optional
from typing import List

from nubison_model.utils import temporary_cwd


@contextmanager
Expand All @@ -20,16 +22,6 @@ def temporary_dirs(dirs: List[str]):
rmtree(dir)


@contextmanager
def temporary_cwd(new_dir):
original_dir = getcwd()
try:
chdir(new_dir)
yield
finally:
chdir(original_dir)


@contextmanager
def temporary_env(env: dict):
original_env = environ.copy()
Expand All @@ -40,3 +32,6 @@ def temporary_env(env: dict):

environ.clear()
environ.update(original_env)


__all__ = ["temporary_cwd", "temporary_dirs", "temporary_env"]

0 comments on commit 587667f

Please sign in to comment.