feat: add context support for parallel inference (#12)
* feat: enhance model loading with context support

- Updated `load_model` method in `UserModel` and `NubisonModel` to accept a `ModelContext` argument, allowing for better handling of worker-specific information during model loading.
- Introduced `ModelContext` type definition to encapsulate worker index and total number of workers for GPU initialization in parallel setups.
- Adjusted related code in service and tests to accommodate the new context parameter.
- Updated documentation in `README.md` to reflect changes in the `load_model` method and its parameters.
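The new `load_model` signature described above can be sketched end-to-end with a stub model; everything in the class body below is illustrative (no real GPU is touched), and the `ModelContext` argument is written out as the plain dict it is:

```python
from typing import Any


class UserModel:
    """Stub model implementing the new load_model(context) signature.

    The cuda device string is illustrative only; nothing here touches a GPU.
    """

    def load_model(self, context) -> None:
        # worker_index is 0-based, so it can index a device list directly
        self.device = f"cuda:{context['worker_index']}"

    def infer(self, input: Any) -> Any:
        return f"inference on {self.device}"


model = UserModel()
# ModelContext is a TypedDict: call sites pass an ordinary dict
model.load_model({"worker_index": 0, "num_workers": 2})
print(model.infer(None))  # inference on cuda:0
```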

* feat: update Dockerfile to support parallel inference

- Added Open Container Initiative (OCI) labels for better image description and source tracking.
- Introduced a new environment variable `NUM_WORKERS` with a default value of 4 to configure the number of workers for the application.

* refactor: reduce default worker count for improved resource management

- Updated the Dockerfile to change the `NUM_WORKERS` environment variable from 4 to 2, reducing the default resource footprint of the container image.
- Lowered the default worker count in `Service.py` from 4 to 1, so multi-worker loading is opt-in when running outside the container.
kyuwoo-choi authored Jan 2, 2025
1 parent 9b25ba4 commit 8f73ebb
Showing 9 changed files with 179 additions and 44 deletions.
7 changes: 6 additions & 1 deletion container/Dockerfile
@@ -13,7 +13,11 @@ SHELL ["bash", "-c"]
# Set the working directory
WORKDIR /app

-# Copy the entrypoint script into the container
+# Add label descriptions
+LABEL org.opencontainers.image.title="nubison-model" \
+      org.opencontainers.image.description="A container image for nubison-model." \
+      org.opencontainers.image.source="https://github.com/nubison/nubison-model"
+
 COPY start_server.sh /app/start_server.sh
 RUN chmod +x /app/start_server.sh

@@ -22,6 +26,7 @@ ENV MLFLOW_TRACKING_URI=""
ENV MLFLOW_MODEL_URI=""
ENV DEBUG=""
ENV PORT=3000
ENV NUM_WORKERS=2

# Expose the port specified by the PORT environment variable
EXPOSE ${PORT}
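On the Python side, this variable is parsed back out of the environment (see the `Service.py` diff later in this commit); a sketch of that lookup, where the helper name is illustrative and only the variable name and the fallback of 1 come from the commit:

```python
from os import getenv

DEFAULT_NUM_WORKERS = 1


def resolve_num_workers() -> int:
    """Resolve the worker count the way Service.py does.

    The NUM_WORKERS environment variable wins; an unset or empty value
    (the Dockerfile ENV default is a concrete "2", but other deployments
    may leave it blank) falls back to DEFAULT_NUM_WORKERS.
    """
    # `or` catches both None (unset) and "" (set but empty)
    return int(getenv("NUM_WORKERS") or DEFAULT_NUM_WORKERS)
```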
4 changes: 4 additions & 0 deletions example/README.md
@@ -38,6 +38,10 @@ The `model.ipynb` file shows how to register a user model. It contains the follo

- Use this method to prepare the model for inference, which can be time-consuming.
- This method is called once when the model inference server starts.
- The `load_model` method receives a `ModelContext` dictionary containing:
- `worker_index`: Index of the worker process (0-based) for parallel processing
- `num_workers`: Total number of workers running the model
- This information is particularly useful for GPU initialization in parallel setups, where you can map specific workers to specific GPU devices.
- The path to the model weights file can be specified as a relative path.
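When there are more workers than GPUs, one common pattern is to fold `worker_index` onto the available devices round-robin; the helper below is a sketch (the `visible_gpus` argument is hypothetical, not part of the library):

```python
def pick_gpu(worker_index: int, visible_gpus: int) -> int:
    """Map a 0-based worker index onto a GPU id round-robin."""
    if visible_gpus <= 0:
        raise ValueError("no GPUs visible")
    return worker_index % visible_gpus


# Four workers sharing two GPUs: workers 0 and 2 land on GPU 0,
# workers 1 and 3 on GPU 1.
print([pick_gpu(i, 2) for i in range(4)])  # [0, 1, 0, 1]
```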

#### `infer` method
50 changes: 24 additions & 26 deletions example/model.ipynb
@@ -16,15 +16,16 @@
"# The `NubisonModel` class serves as a base class for creating custom user model classes.\n",
"# Note that modules required by UserModel must be imported within the NubisonModel class.\n",
"# This is because the UserModel is cloudpickled, and using modules imported outside of the NubisonModel class will cause errors.\n",
-    "from nubison_model import NubisonModel\n",
+    "from nubison_model import NubisonModel, ModelContext\n",
     "\n",
     "class UserModel(NubisonModel):\n",
-    "    \"\"\"\n",
-    "    A user model that extends the `NubisonModel` base class.\n",
-    "    \"\"\"\n",
-    "    def load_model(self) -> None:\n",
-    "        \"\"\"\n",
-    "        This method is used to load the model weights from the file.\n",
+    "    \"\"\"A user model that extends the NubisonModel base class.\"\"\"\n",
+    "    \n",
+    "    def load_model(self, context: ModelContext) -> None:\n",
+    "        \"\"\"Load the model weights from the file.\n",
+    "        \n",
+    "        Args:\n",
+    "            context: Contains worker_index (0-based) for GPU initialization in parallel setups.\n",
     "        \"\"\"\n",
" try:\n",
" # Import the SimpleLinearModel class from the src directory\n",
@@ -51,29 +52,29 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024/11/28 16:58:19 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
"2025/01/02 15:54:12 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
" - nubison-model (current: 0.0.2.dev3+3e1558a.20241118053748, required: nubison-model==0.0.1)\n",
"To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.\n",
"2024/11/28 16:58:19 WARNING mlflow.models.model: Model logged without a signature and input example. Please set `input_example` parameter when logging the model to auto infer the model signature.\n",
"Registered model 'nubison_model' already exists. Creating a new version of this model...\n",
"2024/11/28 16:58:20 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: nubison_model, version 8\n",
"Created version '8' of model 'nubison_model'.\n",
"2024/11/28 16:58:20 INFO mlflow.tracking._tracking_service.client: 🏃 View run bustling-mare-593 at: https://model.nubison.io/#/experiments/1207/runs/dbebfd0f99594a0fa1c67a6a00b3e270.\n",
"2024/11/28 16:58:20 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://model.nubison.io/#/experiments/1207.\n"
"2025/01/02 15:54:12 WARNING mlflow.models.model: Model logged without a signature and input example. Please set `input_example` parameter when logging the model to auto infer the model signature.\n",
"Registered model 'Default' already exists. Creating a new version of this model...\n",
"2025/01/02 15:54:12 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: Default, version 144\n",
"Created version '144' of model 'Default'.\n",
"2025/01/02 15:54:12 INFO mlflow.tracking._tracking_service.client: 🏃 View run rumbling-bass-282 at: http://127.0.0.1:5000/#/experiments/0/runs/ee369aafa91c4753b7bb067acf466a9c.\n",
"2025/01/02 15:54:12 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/0.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model registered: runs:/dbebfd0f99594a0fa1c67a6a00b3e270/\n"
"Model registered: runs:/ee369aafa91c4753b7bb067acf466a9c/\n"
]
}
],
@@ -83,7 +84,7 @@
"\n",
"# Register the user model\n",
"# The `artifact_dirs` argument specifies the folders containing the files used by the model class.\n",
-    "model_id = register(UserModel(), mlflow_uri=\"https://model.nubison.io\", model_name=\"nubison_model\", artifact_dirs=\"src\", params={\"desc\": \"This is a test model\"}, metrics={\"train\": 0.9, \"validation\": 0.8, \"test\": 0.7})\n",
+    "model_id = register(UserModel(), artifact_dirs=\"src\", params={\"desc\": \"This is a test model\"}, metrics={\"train\": 0.9, \"validation\": 0.8, \"test\": 0.7})\n",
"print(f\"Model registered: {model_id}\")\n"
]
},
@@ -96,14 +97,14 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024/11/28 14:29:26 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
"2025/01/02 15:54:15 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
" - nubison-model (current: 0.0.2.dev3+3e1558a.20241118053748, required: nubison-model==0.0.1)\n",
"To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.\n"
]
@@ -112,20 +113,17 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024/11/28 14:29:26 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
" - nubison-model (current: 0.0.2.dev3+3e1558a.20241118053748, required: nubison-model==0.0.1)\n",
"To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.\n",
"2024-11-28 14:29:26,359 - SimpleLinearModel - INFO - Weights loaded successfully from ./src/weights.txt.\n",
"2025-01-02 15:54:15,604 - SimpleLinearModel - INFO - Weights loaded successfully from ./src/weights.txt.\n",
"INFO:SimpleLinearModel:Weights loaded successfully from ./src/weights.txt.\n",
"2024-11-28 14:29:26,367 - SimpleLinearModel - INFO - Calculating the result of the linear model with x1=3.1, x2=2.0.\n",
"2025-01-02 15:54:15,621 - SimpleLinearModel - INFO - Calculating the result of the linear model with x1=3.1, x2=2.0.\n",
"INFO:SimpleLinearModel:Calculating the result of the linear model with x1=3.1, x2=2.0.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prepared artifact: src -> /tmp/tmprsct2rwt/artifacts/src\n",
"Prepared artifact: src -> /tmp/tmpamgz0yua/artifacts/src\n",
"The result of the linear model is 4.35.\n"
]
}
@@ -152,7 +150,7 @@
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "testmodel",
"language": "python",
"name": "python3"
},
60 changes: 55 additions & 5 deletions nubison_model/Model.py
@@ -1,7 +1,7 @@
from importlib.metadata import distributions
from os import getenv, path, symlink
from sys import version_info as py_version_info
from typing import Any, List, Optional, Protocol, runtime_checkable
from typing import Any, List, Optional, Protocol, TypedDict, runtime_checkable

import mlflow
from mlflow.models.model import ModelInfo
@@ -14,11 +14,61 @@
DEFAULT_ARTIFACT_DIRS = "" # Default code paths comma-separated


class ModelContext(TypedDict):
    """Context information passed to model during loading.

    Attributes:
        worker_index: Index of the worker process running the model. Used to identify
            which worker is running the model in a parallel server setup. Starts from 0.
            Even in a single server process setup, this will be 0. This is particularly
            useful for GPU initialization as you can map specific workers to specific
            GPU devices.
        num_workers: Number of workers running the model.
    """

    worker_index: int

    num_workers: int


 @runtime_checkable
 class NubisonModel(Protocol):
-    def load_model(self) -> None: ...
-
-    def infer(self, input: Any) -> Any: ...
+    """Protocol defining the interface for user-defined models.
+
+    Your model class must implement this protocol by providing:
+    1. load_model method - Called once at startup to initialize the model
+    2. infer method - Called for each inference request
+    """
+
+    def load_model(self, context: ModelContext) -> None:
+        """Initialize and load the model.
+
+        This method is called once when the model server starts up.
+        Use it to load model weights and initialize any resources needed for inference.
+
+        Args:
+            context: A dictionary containing worker information:
+                - worker_index: Index of the worker process (0-based)
+                - num_workers: Total number of workers running the model
+
+            This information is particularly useful for GPU initialization
+            in parallel setups, where you can map specific workers to
+            specific GPU devices.
+        """
+        ...
+
+    def infer(self, input: Any) -> Any:
+        """Perform inference on the input.
+
+        This method is called for each inference request.
+
+        Args:
+            input: The input data to perform inference on.
+                Can be of any type that your model accepts.
+
+        Returns:
+            The inference result. Can be of any type that your model produces.
+        """
+        ...


class NubisonMLFlowModel(PythonModel):
Expand Down Expand Up @@ -60,8 +110,8 @@ def predict(self, context, model_input):
def get_nubison_model(self):
return self._nubison_model

-    def load_model(self):
-        self._nubison_model.load_model()
+    def load_model(self, context: ModelContext):
+        self._nubison_model.load_model(context)

def infer(self, *args, **kwargs) -> Any:
return self._nubison_model.infer(*args, **kwargs)
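Since `ModelContext` is a `TypedDict`, call sites pass ordinary dicts; the delegation in `NubisonMLFlowModel.load_model` can be sketched with stubs (the `StubModel`/`Wrapper` names are illustrative, not library API):

```python
from typing import TypedDict


class ModelContext(TypedDict):
    worker_index: int
    num_workers: int


class StubModel:
    """Stands in for a user model; just records the context it was given."""

    def load_model(self, context: ModelContext) -> None:
        self.context = context


class Wrapper:
    """Mirrors the delegation pattern: forwards the context to the wrapped model."""

    def __init__(self, model: StubModel) -> None:
        self._model = model

    def load_model(self, context: ModelContext) -> None:
        self._model.load_model(context)


model = StubModel()
Wrapper(model).load_model({"worker_index": 0, "num_workers": 2})
print(model.context["num_workers"])  # 2
```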
22 changes: 20 additions & 2 deletions nubison_model/Service.py
@@ -17,6 +17,9 @@
)
from nubison_model.utils import temporary_cwd

ENV_VAR_NUM_WORKERS = "NUM_WORKERS"
DEFAULT_NUM_WORKERS = 1


def load_nubison_mlflow_model(mlflow_tracking_uri, mlflow_model_uri):
if not mlflow_tracking_uri:
@@ -63,12 +66,14 @@ def build_inference_service(
)
mlflow_model_uri = mlflow_model_uri or getenv(ENV_VAR_MLFLOW_MODEL_URI) or ""

num_workers = int(getenv(ENV_VAR_NUM_WORKERS) or DEFAULT_NUM_WORKERS)

nubison_mlflow_model = load_nubison_mlflow_model(
mlflow_tracking_uri=mlflow_tracking_uri,
mlflow_model_uri=mlflow_model_uri,
)

-    @bentoml.service
+    @bentoml.service(workers=num_workers)
class BentoMLService:
"""BentoML Service for serving machine learning models."""

Expand All @@ -81,7 +86,20 @@ def __init__(self):
Raises:
RuntimeError: Error loading model from the model registry
"""
-        nubison_mlflow_model.load_model()
+
+        # Default to worker index 0 when no bentoml server context is available
+        # (for example, when running with the test client)
+        context = {
+            "worker_index": 0,
+            "num_workers": 1,
+        }
+        if bentoml.server_context.worker_index is not None:
+            context = {
+                "worker_index": bentoml.server_context.worker_index - 1,
+                "num_workers": num_workers,
+            }
+
+        nubison_mlflow_model.load_model(context)

@bentoml.api
@wraps(nubison_mlflow_model.get_nubison_model_infer_method())
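The worker-context logic in `__init__` is easiest to see as a pure function: BentoML's `server_context.worker_index` is 1-based and `None` outside a running server, hence the subtraction and the single-worker fallback. A sketch (the function name is an assumption, not library API):

```python
from typing import Optional


def make_model_context(bentoml_worker_index: Optional[int], num_workers: int) -> dict:
    """Build the ModelContext dict passed to load_model.

    Falls back to a single worker at index 0 when no BentoML server
    context is available (e.g. under a test client).
    """
    if bentoml_worker_index is None:
        return {"worker_index": 0, "num_workers": 1}
    # BentoML reports a 1-based index; ModelContext uses 0-based
    return {"worker_index": bentoml_worker_index - 1, "num_workers": num_workers}


print(make_model_context(None, 2))  # {'worker_index': 0, 'num_workers': 1}
print(make_model_context(2, 2))     # {'worker_index': 1, 'num_workers': 2}
```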
2 changes: 2 additions & 0 deletions nubison_model/__init__.py
@@ -5,6 +5,7 @@
from .Model import (
ENV_VAR_MLFLOW_MODEL_URI,
ENV_VAR_MLFLOW_TRACKING_URI,
ModelContext,
NubisonMLFlowModel,
NubisonModel,
register,
@@ -14,6 +15,7 @@
__all__ = [
"ENV_VAR_MLFLOW_MODEL_URI",
"ENV_VAR_MLFLOW_TRACKING_URI",
"ModelContext",
"NubisonModel",
"NubisonMLFlowModel",
"register",
5 changes: 3 additions & 2 deletions test/test_integration.py
@@ -1,5 +1,6 @@
from nubison_model import (
ENV_VAR_MLFLOW_MODEL_URI,
ModelContext,
NubisonModel,
Service,
build_inference_service,
@@ -14,7 +15,7 @@ def test_register_and_serve_model(mlflow_server):
"""

class DummyModel(NubisonModel):
-        def load_model(self):
+        def load_model(self, context: ModelContext):
# Try to read the contents of the artifact file
with open("./fixtures/bar.txt", "r") as f:
self.loaded = f.read()
@@ -44,7 +45,7 @@ def test_register_and_test_model(mlflow_server):
"""

class DummyModel(NubisonModel):
-        def load_model(self):
+        def load_model(self, context: ModelContext):
# Try to read the contents of the artifact file
with open("./fixtures/bar.txt", "r") as f:
self.loaded = f.read()
16 changes: 12 additions & 4 deletions test/test_model.py
@@ -3,7 +3,7 @@
import pytest
from mlflow.tracking import MlflowClient

-from nubison_model import NubisonModel, register
+from nubison_model import NubisonModel, register, ModelContext
from nubison_model.Model import _make_artifact_dir_dict, _package_list_from_file
from test.utils import (
get_run_id_from_model_uri,
@@ -21,7 +21,11 @@ def test_register_model(mlflow_server):

# define a simple model (for example purposes, using a dummy model)
class DummyModel(NubisonModel):
-        pass
+        def load_model(self, context: ModelContext):
+            pass
+
+        def infer(self, input):
+            pass

# configure the code directories
artifact_dirs = ["src1", "src2"]
@@ -56,7 +60,7 @@ class WrongModel:
pass

class RightModel(NubisonModel):
-        def load_model(self):
+        def load_model(self, context: ModelContext):
pass

def infer(self, input):
Expand Down Expand Up @@ -102,7 +106,11 @@ def test_log_params_and_metrics(mlflow_server):
model_name = "TestLoggedModel"

class DummyModel(NubisonModel):
-        pass
+        def load_model(self, context: ModelContext):
+            pass
+
+        def infer(self, input):
+            pass

# Test parameters and metrics
test_params = {"param1": "value1", "param2": "value2"}
