From a02515bbcc8ea7342a09b7c2e1762e710432156f Mon Sep 17 00:00:00 2001
From: KyuWoo Choi
Date: Wed, 11 Dec 2024 10:58:59 +0900
Subject: [PATCH] fix: test_client fails due to model loading the wrong way (#10)

---
 nubison_model/Model.py    | 36 ++++++++++++++------------
 nubison_model/Service.py  | 54 ++++++++++++---------------------------
 nubison_model/__init__.py |  2 ++
 test/test_integration.py  | 30 ++++++++++++++++++++++
 test/test_service.py      | 20 +++++++++++----
 5 files changed, 83 insertions(+), 59 deletions(-)

diff --git a/nubison_model/Model.py b/nubison_model/Model.py
index 12e41ce..c521333 100644
--- a/nubison_model/Model.py
+++ b/nubison_model/Model.py
@@ -1,5 +1,5 @@
 from importlib.metadata import distributions
-from os import getenv, path
+from os import getenv, path, symlink
 from sys import version_info as py_version_info
 from typing import Any, List, Optional, Protocol, runtime_checkable
 
@@ -25,20 +25,20 @@ class NubisonMLFlowModel(PythonModel):
     def __init__(self, nubison_model: NubisonModel):
         self._nubison_model = nubison_model
 
-    def _check_artifacts_prepared(self) -> bool:
+    def _check_artifacts_prepared(self, artifacts: dict) -> bool:
         """Check if all symlinks for the artifacts are created successfully."""
-        for name, target_path in self._artifacts.items():
+        for name, target_path in artifacts.items():
             if not path.exists(name):
-                print(f"Symlink for {name} was not created successfully.")
                 return False
-
         return True
 
-    def prepare_artifacts(self) -> None:
-        """Create symbolic links for the artifacts stored in the _artifacts attribute."""
-        from os import path, symlink
+    def prepare_artifacts(self, artifacts: dict) -> None:
+        """Create symbolic links for the artifacts provided as a parameter."""
+        if self._check_artifacts_prepared(artifacts):
+            print("Skipping artifact preparation as it was already done.")
+            return
 
-        for name, target_path in self._artifacts.items():
+        for name, target_path in artifacts.items():
             try:
                 symlink(target_path, name, target_is_directory=path.isdir(target_path))
                 print(f"Prepared artifact: {name} -> {target_path}")
@@ -51,14 +51,7 @@ def load_context(self, context: Any) -> None:
         Args:
             context (PythonModelContext): A collection of artifacts that a PythonModel can use when performing inference.
         """
-        # Check if symlinks are made and proceed if all symlinks are ok
-        self._artifacts = context.artifacts
-
-        if not self._check_artifacts_prepared():
-            print("Artifacts were not prepared. Skipping model loading.")
Skipping model loading.") - return - - self._nubison_model.load_model() + self.prepare_artifacts(context.artifacts) def predict(self, context, model_input): input = model_input["input"] @@ -67,6 +60,15 @@ def predict(self, context, model_input): def get_nubison_model(self): return self._nubison_model + def load_model(self): + self._nubison_model.load_model() + + def infer(self, *args, **kwargs) -> Any: + return self._nubison_model.infer(*args, **kwargs) + + def get_nubison_model_infer_method(self): + return self._nubison_model.__class__.infer + def _is_shareable(package: str) -> bool: # Nested requirements, constraints files, local packages, and comments are not supported diff --git a/nubison_model/Service.py b/nubison_model/Service.py index aa85517..89dc7d0 100644 --- a/nubison_model/Service.py +++ b/nubison_model/Service.py @@ -19,48 +19,39 @@ from nubison_model.utils import temporary_cwd -def load_nubison_model( - mlflow_tracking_uri, - mlflow_model_uri, - prepare_artifacts: bool = False, -): +def load_nubison_mlflow_model(mlflow_tracking_uri, mlflow_model_uri): + if not mlflow_tracking_uri: + raise RuntimeError("MLflow tracking URI is not set") + if not mlflow_model_uri: + raise RuntimeError("MLflow model URI is not set") try: - if not mlflow_tracking_uri: - raise RuntimeError("MLflow tracking URI is not set") - if not mlflow_model_uri: - raise RuntimeError("MLflow model URI is not set") set_tracking_uri(mlflow_tracking_uri) mlflow_model = load_model(model_uri=mlflow_model_uri) - nubison_mlflow_model = cast( NubisonMLFlowModel, mlflow_model.unwrap_python_model() ) - if prepare_artifacts: - nubison_mlflow_model.prepare_artifacts() - - # Get the NubisonModel instance from the MLflow model - nubison_model = nubison_mlflow_model.get_nubison_model() except Exception as e: raise RuntimeError( f"Error loading model(uri: {mlflow_model_uri}) from model registry(uri: {mlflow_tracking_uri})" ) from e - return nubison_model + return nubison_mlflow_model @contextmanager def test_client(model_uri): - app = build_inference_service(mlflow_model_uri=model_uri) - # Disable metrics for testing. Avoids Prometheus client duplicated registration error - app.config["metrics"] = {"enabled": False} # Create a temporary directory and set it as the current working directory to run tests # To avoid model initialization conflicts with the current directory test_dir = TemporaryDirectory() - with temporary_cwd(test_dir.name), TestClient(app.to_asgi()) as client: + with temporary_cwd(test_dir.name): + app = build_inference_service(mlflow_model_uri=model_uri) + # Disable metrics for testing. Avoids Prometheus client duplicated registration error + app.config["metrics"] = {"enabled": False} - yield client + with TestClient(app.to_asgi()) as client: + yield client def build_inference_service( @@ -71,18 +62,15 @@ def build_inference_service( ) mlflow_model_uri = mlflow_model_uri or getenv(ENV_VAR_MLFLOW_MODEL_URI) or "" - nubison_model_class = load_nubison_model( + nubison_mlflow_model = load_nubison_mlflow_model( mlflow_tracking_uri=mlflow_tracking_uri, mlflow_model_uri=mlflow_model_uri, - prepare_artifacts=True, - ).__class__ + ) @bentoml.service class BentoMLService: """BentoML Service for serving machine learning models.""" - _nubison_model = None - def __init__(self): """Initializes the BentoML Service for serving machine learning models. 
@@ -92,14 +80,10 @@ def __init__(self):
             Raises:
                 RuntimeError: Error loading model from the model registry
             """
-            self._nubison_model = load_nubison_model(
-                mlflow_tracking_uri=mlflow_tracking_uri,
-                mlflow_model_uri=mlflow_model_uri,
-                prepare_artifacts=False,
-            )
+            nubison_mlflow_model.load_model()
 
         @bentoml.api
-        @wraps(nubison_model_class.infer)
+        @wraps(nubison_mlflow_model.get_nubison_model_infer_method())
         def infer(self, *args, **kwargs):
             """Proxy method to the NubisonModel.infer method
 
@@ -109,11 +93,7 @@ def infer(self, *args, **kwargs):
             Returns:
                 _type_: The return type of the NubisonModel.infer method
             """
-
-            if self._nubison_model is None:
-                raise RuntimeError("Model is not loaded")
-
-            return self._nubison_model.infer(*args, **kwargs)
+            return nubison_mlflow_model.infer(*args, **kwargs)
 
     return BentoMLService
 
diff --git a/nubison_model/__init__.py b/nubison_model/__init__.py
index 5562a45..1f6c4df 100644
--- a/nubison_model/__init__.py
+++ b/nubison_model/__init__.py
@@ -5,6 +5,7 @@
 from .Model import (
     ENV_VAR_MLFLOW_MODEL_URI,
     ENV_VAR_MLFLOW_TRACKING_URI,
+    NubisonMLFlowModel,
     NubisonModel,
     register,
 )
@@ -14,6 +15,7 @@
     "ENV_VAR_MLFLOW_MODEL_URI",
     "ENV_VAR_MLFLOW_TRACKING_URI",
     "NubisonModel",
+    "NubisonMLFlowModel",
     "register",
     "build_inference_service",
     "test_client",
diff --git a/test/test_integration.py b/test/test_integration.py
index d04a6da..62c6826 100644
--- a/test/test_integration.py
+++ b/test/test_integration.py
@@ -1,6 +1,7 @@
 from nubison_model import (
     ENV_VAR_MLFLOW_MODEL_URI,
     NubisonModel,
+    Service,
     build_inference_service,
     register,
 )
@@ -35,3 +36,32 @@ def infer(self, param1: str):
     ):
         bento_service = build_inference_service()()
         assert bento_service.infer("test") == "bartest"
+
+
+def test_register_and_test_model(mlflow_server):
+    """
+    Test registering a model to MLflow's Model Registry and testing it with BentoML.
+    """
+
+    class DummyModel(NubisonModel):
+        def load_model(self):
+            # Try to read the contents of the artifact file
+            with open("./fixtures/bar.txt", "r") as f:
+                self.loaded = f.read()
+
+        def infer(self, param1: str):
+            # Try to import a function from the artifact code
+            from .fixtures.poo import echo
+
+            return echo(self.loaded + param1)
+
+    # Switch cwd to the current file directory to register the fixture artifact
+    with temporary_cwd("test"):
+        model_uri = register(DummyModel(), artifact_dirs="fixtures")
+
+    # Create a temp dir and switch to it to test the model
+    # so that the artifact symlinks do not collide with the current directory
+    with Service.test_client(model_uri) as client:
+        response = client.post("/infer", json={"param1": "test"})
+        assert response.status_code == 200
+        assert response.text == "bartest"
diff --git a/test/test_service.py b/test/test_service.py
index 32fd11e..713016c 100644
--- a/test/test_service.py
+++ b/test/test_service.py
@@ -2,7 +2,7 @@
 
 import pytest
 
-from nubison_model import build_inference_service, test_client
+from nubison_model import NubisonMLFlowModel, build_inference_service, test_client
 
 
 def test_raise_runtime_error_on_missing_env():
@@ -15,8 +15,13 @@ class DummyModel:
         def infer(self, test: str):
             return test
 
-    with patch("nubison_model.Service.load_nubison_model") as mock_load_nubison_model:
-        mock_load_nubison_model.return_value = DummyModel()
+        def load_model(self):
+            pass
+
+    with patch(
+        "nubison_model.Service.load_nubison_mlflow_model"
+    ) as mock_load_nubison_mlflow_model:
+        mock_load_nubison_mlflow_model.return_value = NubisonMLFlowModel(DummyModel())
 
         service = build_inference_service()()
         assert service.infer("test") == "test"
@@ -26,8 +31,13 @@ class DummyModel:
         def infer(self, test: str):
             return test
 
-    with patch("nubison_model.Service.load_nubison_model") as mock_load_nubison_model:
-        mock_load_nubison_model.return_value = DummyModel()
+        def load_model(self):
+            pass
+
+    with patch(
+        "nubison_model.Service.load_nubison_mlflow_model"
+    ) as mock_load_nubison_mlflow_model:
+        mock_load_nubison_mlflow_model.return_value = NubisonMLFlowModel(DummyModel())
         with test_client("test") as client:
             response = client.post("/infer", json={"test": "test"})
             assert response.status_code == 200
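Note (not part of the patch): a minimal sketch of how the reworked test_client is meant to be exercised
after this change, mirroring the new test_register_and_test_model integration test from user code. The
EchoModel class and its "text" parameter are illustrative, calling register() without artifact_dirs is an
assumption, and MLFLOW_TRACKING_URI is assumed to be set so that register() can reach the model registry.

    from nubison_model import NubisonModel, register, test_client


    class EchoModel(NubisonModel):
        def load_model(self):
            # Invoked once in BentoMLService.__init__ via NubisonMLFlowModel.load_model()
            self.prefix = "echo:"

        def infer(self, text: str):
            return self.prefix + text


    # Register the model to the MLflow Model Registry and get its model URI
    model_uri = register(EchoModel())

    # test_client now builds the inference service inside a temporary working directory,
    # so artifact symlinks are created there rather than in the caller's cwd
    with test_client(model_uri) as client:
        response = client.post("/infer", json={"text": "hi"})
        assert response.status_code == 200
        assert response.text == "echo:hi"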