fix: test_client fails due to model loading the wrong way (#10)
kyuwoo-choi authored Dec 11, 2024
1 parent fde4981 commit a02515b
Showing 5 changed files with 83 additions and 59 deletions.
36 changes: 19 additions & 17 deletions nubison_model/Model.py
@@ -1,5 +1,5 @@
 from importlib.metadata import distributions
-from os import getenv, path
+from os import getenv, path, symlink
 from sys import version_info as py_version_info
 from typing import Any, List, Optional, Protocol, runtime_checkable
@@ -25,20 +25,20 @@ class NubisonMLFlowModel(PythonModel):
     def __init__(self, nubison_model: NubisonModel):
         self._nubison_model = nubison_model

-    def _check_artifacts_prepared(self) -> bool:
+    def _check_artifacts_prepared(self, artifacts: dict) -> bool:
         """Check if all symlinks for the artifacts are created successfully."""
-        for name, target_path in self._artifacts.items():
+        for name, target_path in artifacts.items():
             if not path.exists(name):
                 print(f"Symlink for {name} was not created successfully.")
                 return False

         return True

-    def prepare_artifacts(self) -> None:
-        """Create symbolic links for the artifacts stored in the _artifacts attribute."""
-        from os import path, symlink
+    def prepare_artifacts(self, artifacts: dict) -> None:
+        """Create symbolic links for the artifacts provided as a parameter."""
+        if self._check_artifacts_prepared(artifacts):
+            print("Skipping artifact preparation as it was already done.")
+            return

-        for name, target_path in self._artifacts.items():
+        for name, target_path in artifacts.items():
             try:
                 symlink(target_path, name, target_is_directory=path.isdir(target_path))
                 print(f"Prepared artifact: {name} -> {target_path}")
@@ -51,14 +51,7 @@ def load_context(self, context: Any) -> None:
         Args:
             context (PythonModelContext): A collection of artifacts that a PythonModel can use when performing inference.
         """
-        # Check if symlinks are made and proceed if all symlinks are ok
-        self._artifacts = context.artifacts
-
-        if not self._check_artifacts_prepared():
-            print("Artifacts were not prepared. Skipping model loading.")
-            return
-
-        self._nubison_model.load_model()
+        self.prepare_artifacts(context.artifacts)

     def predict(self, context, model_input):
         input = model_input["input"]
@@ -67,6 +60,15 @@ def predict(self, context, model_input):
     def get_nubison_model(self):
         return self._nubison_model

+    def load_model(self):
+        self._nubison_model.load_model()
+
+    def infer(self, *args, **kwargs) -> Any:
+        return self._nubison_model.infer(*args, **kwargs)
+
+    def get_nubison_model_infer_method(self):
+        return self._nubison_model.__class__.infer
+

 def _is_shareable(package: str) -> bool:
     # Nested requirements, constraints files, local packages, and comments are not supported
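
Net effect of these changes: load_context (which MLflow invokes when the model is loaded from the registry) now only prepares artifact symlinks, idempotently, while loading the user model moves behind an explicit load_model() call driven by the serving layer. A runnable toy sketch of that split, using stand-in classes rather than the real NubisonMLFlowModel:

class UserModel:
    def load_model(self):
        self.ready = True  # stands in for loading real weights

class Wrapper:  # stand-in for NubisonMLFlowModel above
    def __init__(self, model):
        self._model = model

    def load_context(self, artifacts: dict):
        # Symlink-only step, safe to repeat; no model weights are loaded here.
        print(f"preparing {len(artifacts)} artifact link(s)")

    def load_model(self):
        self._model.load_model()  # explicit second step, called by the service

w = Wrapper(UserModel())
w.load_context({})  # MLflow triggers this when loading the model
w.load_model()      # the BentoML service triggers this once in __init__
assert w._model.ready
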
54 changes: 17 additions & 37 deletions nubison_model/Service.py
@@ -19,48 +19,39 @@
 from nubison_model.utils import temporary_cwd


-def load_nubison_model(
-    mlflow_tracking_uri,
-    mlflow_model_uri,
-    prepare_artifacts: bool = False,
-):
+def load_nubison_mlflow_model(mlflow_tracking_uri, mlflow_model_uri):
+    if not mlflow_tracking_uri:
+        raise RuntimeError("MLflow tracking URI is not set")
+    if not mlflow_model_uri:
+        raise RuntimeError("MLflow model URI is not set")
+
     try:
-        if not mlflow_tracking_uri:
-            raise RuntimeError("MLflow tracking URI is not set")
-        if not mlflow_model_uri:
-            raise RuntimeError("MLflow model URI is not set")
         set_tracking_uri(mlflow_tracking_uri)
         mlflow_model = load_model(model_uri=mlflow_model_uri)

         nubison_mlflow_model = cast(
             NubisonMLFlowModel, mlflow_model.unwrap_python_model()
         )
-        if prepare_artifacts:
-            nubison_mlflow_model.prepare_artifacts()
-
-        # Get the NubisonModel instance from the MLflow model
-        nubison_model = nubison_mlflow_model.get_nubison_model()
     except Exception as e:
         raise RuntimeError(
             f"Error loading model(uri: {mlflow_model_uri}) from model registry(uri: {mlflow_tracking_uri})"
         ) from e

-    return nubison_model
+    return nubison_mlflow_model


 @contextmanager
 def test_client(model_uri):
-    app = build_inference_service(mlflow_model_uri=model_uri)
-    # Disable metrics for testing. Avoids Prometheus client duplicated registration error
-    app.config["metrics"] = {"enabled": False}
-
     # Create a temporary directory and set it as the current working directory to run tests
     # To avoid model initialization conflicts with the current directory
     test_dir = TemporaryDirectory()
-    with temporary_cwd(test_dir.name), TestClient(app.to_asgi()) as client:
+    with temporary_cwd(test_dir.name):
+        app = build_inference_service(mlflow_model_uri=model_uri)
+        # Disable metrics for testing. Avoids Prometheus client duplicated registration error
+        app.config["metrics"] = {"enabled": False}

-        yield client
+        with TestClient(app.to_asgi()) as client:
+            yield client


 def build_inference_service(
@@ -71,18 +62,15 @@ def build_inference_service(
     )
     mlflow_model_uri = mlflow_model_uri or getenv(ENV_VAR_MLFLOW_MODEL_URI) or ""

-    nubison_model_class = load_nubison_model(
+    nubison_mlflow_model = load_nubison_mlflow_model(
         mlflow_tracking_uri=mlflow_tracking_uri,
         mlflow_model_uri=mlflow_model_uri,
-        prepare_artifacts=True,
-    ).__class__
+    )

     @bentoml.service
     class BentoMLService:
         """BentoML Service for serving machine learning models."""

-        _nubison_model = None
-
         def __init__(self):
             """Initializes the BentoML Service for serving machine learning models.
@@ -92,14 +80,10 @@ def __init__(self):

             Raises:
                 RuntimeError: Error loading model from the model registry
             """
-            self._nubison_model = load_nubison_model(
-                mlflow_tracking_uri=mlflow_tracking_uri,
-                mlflow_model_uri=mlflow_model_uri,
-                prepare_artifacts=False,
-            )
+            nubison_mlflow_model.load_model()

         @bentoml.api
-        @wraps(nubison_model_class.infer)
+        @wraps(nubison_mlflow_model.get_nubison_model_infer_method())
         def infer(self, *args, **kwargs):
             """Proxy method to the NubisonModel.infer method
@@ -109,11 +93,7 @@ def infer(self, *args, **kwargs):

             Returns:
                 _type_: The return type of the NubisonModel.infer method
             """
-
-            if self._nubison_model is None:
-                raise RuntimeError("Model is not loaded")
-
-            return self._nubison_model.infer(*args, **kwargs)
+            return nubison_mlflow_model.infer(*args, **kwargs)

     return BentoMLService
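
For tests, the key change is that test_client now enters the temporary working directory before building the service, so artifact symlinks created during model loading land in the temp dir rather than the caller's cwd. A minimal usage sketch (the /infer route and payload shape follow the tests below; the model URI is a hypothetical placeholder):

from nubison_model import test_client

with test_client("models:/DummyModel/1") as client:
    response = client.post("/infer", json={"param1": "test"})
    print(response.status_code, response.text)
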
2 changes: 2 additions & 0 deletions nubison_model/__init__.py
@@ -5,6 +5,7 @@
 from .Model import (
     ENV_VAR_MLFLOW_MODEL_URI,
     ENV_VAR_MLFLOW_TRACKING_URI,
+    NubisonMLFlowModel,
     NubisonModel,
     register,
 )
@@ -14,6 +15,7 @@
     "ENV_VAR_MLFLOW_MODEL_URI",
     "ENV_VAR_MLFLOW_TRACKING_URI",
     "NubisonModel",
+    "NubisonMLFlowModel",
     "register",
     "build_inference_service",
     "test_client",
30 changes: 30 additions & 0 deletions test/test_integration.py
@@ -1,6 +1,7 @@
 from nubison_model import (
     ENV_VAR_MLFLOW_MODEL_URI,
     NubisonModel,
+    Service,
     build_inference_service,
     register,
 )
@@ -35,3 +36,32 @@ def infer(self, param1: str):
     ):
         bento_service = build_inference_service()()
         assert bento_service.infer("test") == "bartest"
+
+
+def test_register_and_test_model(mlflow_server):
+    """
+    Test registering a model to MLflow's Model Registry and testing it with BentoML.
+    """
+
+    class DummyModel(NubisonModel):
+        def load_model(self):
+            # Try to read the contents of the artifact file
+            with open("./fixtures/bar.txt", "r") as f:
+                self.loaded = f.read()
+
+        def infer(self, param1: str):
+            # Try to import a function from the artifact code
+            from .fixtures.poo import echo
+
+            return echo(self.loaded + param1)
+
+    # Switch cwd to the current file directory to register the fixture artifact
+    with temporary_cwd("test"):
+        model_uri = register(DummyModel(), artifact_dirs="fixtures")
+
+    # Create a temp dir and switch into it to test the model,
+    # so the artifact symlinks do not collide with the current directory
+    with Service.test_client(model_uri) as client:
+        response = client.post("/infer", json={"param1": "test"})
+        assert response.status_code == 200
+        assert response.text == "bartest"
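
This test leans on two fixture files that sit outside the diff; inferred from the assertions above (and therefore an assumption, not part of the commit), their shape would be roughly:

# test/fixtures/poo.py — assumed fixture module
def echo(value: str) -> str:
    return value  # identity, so infer returns self.loaded + param1

# test/fixtures/bar.txt is assumed to contain exactly "bar" (no trailing newline),
# which makes infer("test") evaluate to echo("bar" + "test") == "bartest".
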
20 changes: 15 additions & 5 deletions test/test_service.py
@@ -2,7 +2,7 @@

 import pytest

-from nubison_model import build_inference_service, test_client
+from nubison_model import NubisonMLFlowModel, build_inference_service, test_client


 def test_raise_runtime_error_on_missing_env():
@@ -15,8 +15,13 @@ class DummyModel:
         def infer(self, test: str):
             return test

-    with patch("nubison_model.Service.load_nubison_model") as mock_load_nubison_model:
-        mock_load_nubison_model.return_value = DummyModel()
+        def load_model(self):
+            pass
+
+    with patch(
+        "nubison_model.Service.load_nubison_mlflow_model"
+    ) as mock_load_nubison_mlflow_model:
+        mock_load_nubison_mlflow_model.return_value = NubisonMLFlowModel(DummyModel())
         service = build_inference_service()()
         assert service.infer("test") == "test"
@@ -26,8 +31,13 @@ class DummyModel:
         def infer(self, test: str):
             return test

-    with patch("nubison_model.Service.load_nubison_model") as mock_load_nubison_model:
-        mock_load_nubison_model.return_value = DummyModel()
+        def load_model(self):
+            pass
+
+    with patch(
+        "nubison_model.Service.load_nubison_mlflow_model"
+    ) as mock_load_nubison_mlflow_model:
+        mock_load_nubison_mlflow_model.return_value = NubisonMLFlowModel(DummyModel())
         with test_client("test") as client:
             response = client.post("/infer", json={"test": "test"})
             assert response.status_code == 200
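
Both tests now give DummyModel a load_model stub because the rebuilt service calls load_model() on the wrapper during __init__, and NubisonMLFlowModel.load_model forwards to the wrapped model (see Model.py above); without the stub, NubisonMLFlowModel(DummyModel()) would raise AttributeError at service construction. A quick self-contained check of that delegation:

from nubison_model import NubisonMLFlowModel

class DummyModel:
    def load_model(self):
        pass  # required now that the service invokes it at startup

    def infer(self, test: str):
        return test

wrapper = NubisonMLFlowModel(DummyModel())
wrapper.load_model()  # forwards to DummyModel.load_model
assert wrapper.infer(test="echo") == "echo"  # forwards to DummyModel.infer
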
