diff --git a/container/Dockerfile b/container/Dockerfile
index 817117b..49fb431 100644
--- a/container/Dockerfile
+++ b/container/Dockerfile
@@ -13,7 +13,11 @@ SHELL ["bash", "-c"]
 # Set the working directory
 WORKDIR /app
 
-# Copy the entrypoint script into the container
+# Add label descriptions
+LABEL org.opencontainers.image.title="nubison-model" \
+      org.opencontainers.image.description="A container image for nubison-model." \
+      org.opencontainers.image.source="https://github.com/nubison/nubison-model"
+
 COPY start_server.sh /app/start_server.sh
 RUN chmod +x /app/start_server.sh
 
@@ -22,6 +26,7 @@ ENV MLFLOW_TRACKING_URI=""
 ENV MLFLOW_MODEL_URI=""
 ENV DEBUG=""
 ENV PORT=3000
+ENV NUM_WORKERS=2
 
 # Expose the port specified by the PORT environment variable
 EXPOSE ${PORT}
diff --git a/example/README.md b/example/README.md
index 5b699d8..4f6c1c4 100644
--- a/example/README.md
+++ b/example/README.md
@@ -38,6 +38,10 @@ The `model.ipynb` file shows how to register a user model. It contains the follo
 - Use this method to prepare the model for inference which can be time-consuming.
 - This method is called once when the model inference server starts.
+- The `load_model` method receives a `ModelContext` dictionary containing:
+  - `worker_index`: Index of the worker process (0-based) for parallel processing
+  - `num_workers`: Total number of workers running the model
+- This information is particularly useful for GPU initialization in parallel setups, where you can map specific workers to specific GPU devices.
 - The path to the model weights file can be specified as a relative path.
 
 #### `infer` method
diff --git a/example/model.ipynb b/example/model.ipynb
index baa4917..c55fb2c 100644
--- a/example/model.ipynb
+++ b/example/model.ipynb
@@ -16,15 +16,16 @@
     "# The `NubisonModel` class serves as a base class for creating custom user model classes.\n",
     "# Note that modules required by UserModel must be imported within the NubisonModel class.\n",
     "# This is because the UserModel is cloudpickled, and using modules imported outside of the NubisonModel class will cause errors.\n",
-    "from nubison_model import NubisonModel\n",
+    "from nubison_model import NubisonModel, ModelContext\n",
     "\n",
     "class UserModel(NubisonModel):\n",
-    "    \"\"\"\n",
-    "    A user model that extends the `NubisonModel` base class.\n",
-    "    \"\"\"\n",
-    "    def load_model(self) -> None:\n",
-    "        \"\"\"\n",
-    "        This method is used to load the model weights from the file.\n",
+    "    \"\"\"A user model that extends the NubisonModel base class.\"\"\"\n",
+    "    \n",
+    "    def load_model(self, context: ModelContext) -> None:\n",
+    "        \"\"\"Load the model weights from the file.\n",
+    "        \n",
+    "        Args:\n",
+    "            context: Contains worker_index (0-based) for GPU initialization in parallel setups.\n",
     "        \"\"\"\n",
     "        try:\n",
     "            # Import the SimpleLinearModel class from the src directory\n",
@@ -51,29 +52,29 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "2024/11/28 16:58:19 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
+      "2025/01/02 15:54:12 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
       " - nubison-model (current: 0.0.2.dev3+3e1558a.20241118053748, required: nubison-model==0.0.1)\n",
       "To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.\n",
-      "2024/11/28 16:58:19 WARNING mlflow.models.model: Model logged without a signature and input example. Please set `input_example` parameter when logging the model to auto infer the model signature.\n",
-      "Registered model 'nubison_model' already exists. Creating a new version of this model...\n",
-      "2024/11/28 16:58:20 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: nubison_model, version 8\n",
-      "Created version '8' of model 'nubison_model'.\n",
-      "2024/11/28 16:58:20 INFO mlflow.tracking._tracking_service.client: 🏃 View run bustling-mare-593 at: https://model.nubison.io/#/experiments/1207/runs/dbebfd0f99594a0fa1c67a6a00b3e270.\n",
-      "2024/11/28 16:58:20 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: https://model.nubison.io/#/experiments/1207.\n"
+      "2025/01/02 15:54:12 WARNING mlflow.models.model: Model logged without a signature and input example. Please set `input_example` parameter when logging the model to auto infer the model signature.\n",
+      "Registered model 'Default' already exists. Creating a new version of this model...\n",
+      "2025/01/02 15:54:12 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: Default, version 144\n",
+      "Created version '144' of model 'Default'.\n",
+      "2025/01/02 15:54:12 INFO mlflow.tracking._tracking_service.client: 🏃 View run rumbling-bass-282 at: http://127.0.0.1:5000/#/experiments/0/runs/ee369aafa91c4753b7bb067acf466a9c.\n",
+      "2025/01/02 15:54:12 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/0.\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Model registered: runs:/dbebfd0f99594a0fa1c67a6a00b3e270/\n"
+      "Model registered: runs:/ee369aafa91c4753b7bb067acf466a9c/\n"
      ]
     }
    ],
@@ -83,7 +84,7 @@
     "\n",
     "# Register the user model\n",
     "# The `artifact_dirs` argument specifies the folders containing the files used by the model class.\n",
-    "model_id = register(UserModel(), mlflow_uri=\"https://model.nubison.io\", model_name=\"nubison_model\", artifact_dirs=\"src\", params={\"desc\": \"This is a test model\"}, metrics={\"train\": 0.9, \"validation\": 0.8, \"test\": 0.7})\n",
+    "model_id = register(UserModel(), artifact_dirs=\"src\", params={\"desc\": \"This is a test model\"}, metrics={\"train\": 0.9, \"validation\": 0.8, \"test\": 0.7})\n",
     "print(f\"Model registered: {model_id}\")\n"
    ]
   },
@@ -96,14 +97,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "2024/11/28 14:29:26 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
+      "2025/01/02 15:54:15 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
       " - nubison-model (current: 0.0.2.dev3+3e1558a.20241118053748, required: nubison-model==0.0.1)\n",
       "To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.\n"
      ]
     },
@@ -112,12 +113,9 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "2024/11/28 14:29:26 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
-      " - nubison-model (current: 0.0.2.dev3+3e1558a.20241118053748, required: nubison-model==0.0.1)\n",
-      "To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.\n",
-      "2024-11-28 14:29:26,359 - SimpleLinearModel - INFO - Weights loaded successfully from ./src/weights.txt.\n",
+      "2025-01-02 15:54:15,604 - SimpleLinearModel - INFO - Weights loaded successfully from ./src/weights.txt.\n",
       "INFO:SimpleLinearModel:Weights loaded successfully from ./src/weights.txt.\n",
-      "2024-11-28 14:29:26,367 - SimpleLinearModel - INFO - Calculating the result of the linear model with x1=3.1, x2=2.0.\n",
+      "2025-01-02 15:54:15,621 - SimpleLinearModel - INFO - Calculating the result of the linear model with x1=3.1, x2=2.0.\n",
       "INFO:SimpleLinearModel:Calculating the result of the linear model with x1=3.1, x2=2.0.\n"
      ]
     },
@@ -125,7 +123,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Prepared artifact: src -> /tmp/tmprsct2rwt/artifacts/src\n",
+      "Prepared artifact: src -> /tmp/tmpamgz0yua/artifacts/src\n",
       "The result of the linear model is 4.35.\n"
      ]
     }
@@ -152,7 +150,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": ".venv",
+   "display_name": "testmodel",
    "language": "python",
    "name": "python3"
   },
diff --git a/nubison_model/Model.py b/nubison_model/Model.py
index c521333..9406176 100644
--- a/nubison_model/Model.py
+++ b/nubison_model/Model.py
@@ -1,7 +1,7 @@
 from importlib.metadata import distributions
 from os import getenv, path, symlink
 from sys import version_info as py_version_info
-from typing import Any, List, Optional, Protocol, runtime_checkable
+from typing import Any, List, Optional, Protocol, TypedDict, runtime_checkable
 
 import mlflow
 from mlflow.models.model import ModelInfo
@@ -14,11 +14,61 @@
 DEFAULT_ARTIFACT_DIRS = ""  # Default code paths comma-separated
 
 
+class ModelContext(TypedDict):
+    """Context information passed to the model during loading.
+
+    Attributes:
+        worker_index: Index of the worker process running the model. Used to identify
+            which worker is running the model in a parallel server setup. Starts from 0.
+            Even in a single server process setup, this will be 0. This is particularly
+            useful for GPU initialization as you can map specific workers to specific
+            GPU devices.
+        num_workers: Number of workers running the model.
+    """
+
+    worker_index: int
+
+    num_workers: int
+
+
 @runtime_checkable
 class NubisonModel(Protocol):
-    def load_model(self) -> None: ...
+    """Protocol defining the interface for user-defined models.
+
+    Your model class must implement this protocol by providing:
+    1. load_model method - Called once at startup to initialize the model
+    2. infer method - Called for each inference request
+    """
+
+    def load_model(self, context: ModelContext) -> None:
+        """Initialize and load the model.
+
+        This method is called once when the model server starts up.
+        Use it to load model weights and initialize any resources needed for inference.
 
-    def infer(self, input: Any) -> Any: ...
+        Args:
+            context: A dictionary containing worker information:
+                - worker_index: Index of the worker process (0-based)
+                - num_workers: Total number of workers running the model
+                This information is particularly useful for GPU initialization
+                in parallel setups, where you can map specific workers to
+                specific GPU devices.
+        """
+        ...
+
+    def infer(self, input: Any) -> Any:
+        """Perform inference on the input.
+
+        This method is called for each inference request.
+
+        Args:
+            input: The input data to perform inference on.
+                Can be of any type that your model accepts.
+
+        Returns:
+            The inference result. Can be of any type that your model produces.
+        """
+        ...
 
 
 class NubisonMLFlowModel(PythonModel):
@@ -60,8 +110,8 @@ def predict(self, context, model_input):
     def get_nubison_model(self):
         return self._nubison_model
 
-    def load_model(self):
-        self._nubison_model.load_model()
+    def load_model(self, context: ModelContext):
+        self._nubison_model.load_model(context)
 
     def infer(self, *args, **kwargs) -> Any:
         return self._nubison_model.infer(*args, **kwargs)
diff --git a/nubison_model/Service.py b/nubison_model/Service.py
index 92c9528..48d7a58 100644
--- a/nubison_model/Service.py
+++ b/nubison_model/Service.py
@@ -17,6 +17,9 @@
 )
 from nubison_model.utils import temporary_cwd
 
+ENV_VAR_NUM_WORKERS = "NUM_WORKERS"
+DEFAULT_NUM_WORKERS = 1
+
 
 def load_nubison_mlflow_model(mlflow_tracking_uri, mlflow_model_uri):
     if not mlflow_tracking_uri:
@@ -63,12 +66,14 @@ def build_inference_service(
     )
     mlflow_model_uri = mlflow_model_uri or getenv(ENV_VAR_MLFLOW_MODEL_URI) or ""
 
+    num_workers = int(getenv(ENV_VAR_NUM_WORKERS) or DEFAULT_NUM_WORKERS)
+
     nubison_mlflow_model = load_nubison_mlflow_model(
         mlflow_tracking_uri=mlflow_tracking_uri,
         mlflow_model_uri=mlflow_model_uri,
     )
 
-    @bentoml.service
+    @bentoml.service(workers=num_workers)
     class BentoMLService:
         """BentoML Service for serving machine learning models."""
 
@@ -81,7 +86,20 @@ def __init__(self):
             Raises:
                 RuntimeError: Error loading model from the model registry
             """
-            nubison_mlflow_model.load_model()
+
+            # Default to worker index 0 when no BentoML server context is available,
+            # for example when running with the test client
+            context = {
+                "worker_index": 0,
+                "num_workers": 1,
+            }
+            if bentoml.server_context.worker_index is not None:
+                context = {
+                    "worker_index": bentoml.server_context.worker_index - 1,
+                    "num_workers": num_workers,
+                }
+
+            nubison_mlflow_model.load_model(context)
 
         @bentoml.api
         @wraps(nubison_mlflow_model.get_nubison_model_infer_method())
diff --git a/nubison_model/__init__.py b/nubison_model/__init__.py
index 1f6c4df..5ea3eab 100644
--- a/nubison_model/__init__.py
+++ b/nubison_model/__init__.py
@@ -5,6 +5,7 @@
 from .Model import (
     ENV_VAR_MLFLOW_MODEL_URI,
     ENV_VAR_MLFLOW_TRACKING_URI,
+    ModelContext,
     NubisonMLFlowModel,
     NubisonModel,
     register,
@@ -14,6 +15,7 @@
 __all__ = [
     "ENV_VAR_MLFLOW_MODEL_URI",
     "ENV_VAR_MLFLOW_TRACKING_URI",
+    "ModelContext",
     "NubisonModel",
     "NubisonMLFlowModel",
     "register",
diff --git a/test/test_integration.py b/test/test_integration.py
index 62c6826..c6158b9 100644
--- a/test/test_integration.py
+++ b/test/test_integration.py
@@ -1,5 +1,6 @@
 from nubison_model import (
     ENV_VAR_MLFLOW_MODEL_URI,
+    ModelContext,
     NubisonModel,
     Service,
     build_inference_service,
@@ -14,7 +15,7 @@ def test_register_and_serve_model(mlflow_server):
     """
 
     class DummyModel(NubisonModel):
-        def load_model(self):
+        def load_model(self, context: ModelContext):
             # Try to read the contents of the artifact file
             with open("./fixtures/bar.txt", "r") as f:
                 self.loaded = f.read()
@@ -44,7 +45,7 @@ def test_register_and_test_model(mlflow_server):
     """
 
     class DummyModel(NubisonModel):
-        def load_model(self):
+        def load_model(self, context: ModelContext):
             # Try to read the contents of the artifact file
             with open("./fixtures/bar.txt", "r") as f:
                 self.loaded = f.read()
diff --git a/test/test_model.py b/test/test_model.py
index a507deb..5e376b3 100644
--- a/test/test_model.py
+++ b/test/test_model.py
@@ -3,7 +3,7 @@
 import pytest
 from mlflow.tracking import MlflowClient
 
-from nubison_model import NubisonModel, register
+from nubison_model import NubisonModel, register, ModelContext
 from nubison_model.Model import _make_artifact_dir_dict, _package_list_from_file
 from test.utils import (
     get_run_id_from_model_uri,
@@ -21,7 +21,11 @@ def test_register_model(mlflow_server):
 
     # define a simple model (for example purposes, using a dummy model)
     class DummyModel(NubisonModel):
-        pass
+        def load_model(self, context: ModelContext):
+            pass
+
+        def infer(self, input):
+            pass
 
     # configure the code directories
     artifact_dirs = ["src1", "src2"]
@@ -56,7 +60,7 @@ class WrongModel:
         pass
 
     class RightModel(NubisonModel):
-        def load_model(self):
+        def load_model(self, context: ModelContext):
             pass
 
         def infer(self, input):
@@ -102,7 +106,11 @@ def test_log_params_and_metrics(mlflow_server):
     model_name = "TestLoggedModel"
 
     class DummyModel(NubisonModel):
-        pass
+        def load_model(self, context: ModelContext):
+            pass
+
+        def infer(self, input):
+            pass
 
     # Test parameters and metrics
     test_params = {"param1": "value1", "param2": "value2"}
diff --git a/test/test_service.py b/test/test_service.py
index 7d04c35..583f47d 100644
--- a/test/test_service.py
+++ b/test/test_service.py
@@ -4,7 +4,14 @@
 from PIL.Image import Image
 from PIL.Image import open as open_image
 
-from nubison_model import NubisonMLFlowModel, build_inference_service, test_client
+from nubison_model import (
+    ModelContext,
+    NubisonMLFlowModel,
+    build_inference_service,
+    test_client,
+)
+from nubison_model.Service import DEFAULT_NUM_WORKERS
+from test.utils import temporary_env
 
 
 def test_raise_runtime_error_on_missing_env():
@@ -17,7 +24,7 @@ class DummyModel:
         def infer(self, test: str):
             return test
 
-        def load_model(self):
+        def load_model(self, context: ModelContext):
             pass
 
     with patch(
@@ -33,7 +40,7 @@ class DummyModel:
         def infer(self, test: str):
             return test
 
-        def load_model(self):
+        def load_model(self, context: ModelContext):
             pass
 
     with patch(
@@ -51,7 +58,7 @@ class DummyModel:
         def infer(self, test: Image):
             return test.size
 
-        def load_model(self):
+        def load_model(self, context: ModelContext):
             pass
 
     with patch(
@@ -69,5 +76,47 @@ def load_model(self):
     assert response.json() == [100, 100]
 
 
+def test_model_context():
+    class DummyModel:
+        def __init__(self):
+            self.context = None
+
+        def infer(self, test: str):
+            return test
+
+        def load_model(self, context: ModelContext):
+            self.context = context
+
+    with patch(
+        "nubison_model.Service.load_nubison_mlflow_model"
+    ) as mock_load_nubison_mlflow_model:
+        dummy_model = DummyModel()
+        mock_load_nubison_mlflow_model.return_value = NubisonMLFlowModel(dummy_model)
+
+        # Default NUM_WORKERS should apply when the env var is unset
+        with patch("bentoml.server_context.worker_index", 1), temporary_env({}):
+            service = build_inference_service()()
+            assert dummy_model.context == {
+                "worker_index": 0,
+                "num_workers": DEFAULT_NUM_WORKERS,
+            }, "Default num_workers should be applied"
+
+        with patch("bentoml.server_context.worker_index", 1), temporary_env(
+            {"NUM_WORKERS": "8"}
+        ):
+            service = build_inference_service()()
+            assert dummy_model.context == {
+                "worker_index": 0,
+                "num_workers": 8,
+            }, "Custom num_workers should be applied"
+
+        with patch("bentoml.server_context.worker_index", None), temporary_env({}):
+            service = build_inference_service()()
+            assert dummy_model.context == {
+                "worker_index": 0,
+                "num_workers": 1,
+            }, "When worker_index is unavailable, worker_index should default to 0 and num_workers to 1"
+
+
 # Ignore the test_client from being collected by pytest
 setattr(test_client, "__test__", False)
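
For illustration, below is a minimal sketch (not part of the diff above) of how a user model might consume the new `ModelContext` to pin each worker to a GPU, as the docstrings and README changes suggest. It assumes PyTorch is installed; the `SimpleGPUModel` name, the placeholder `torch.nn.Linear` weights, and the round-robin device scheme are hypothetical, not part of this changeset.

```python
# Hypothetical usage sketch: mapping BentoML workers to GPU devices via
# the ModelContext introduced in this diff. Assumes PyTorch is available.
from typing import Any

from nubison_model import ModelContext, NubisonModel


class SimpleGPUModel(NubisonModel):
    def load_model(self, context: ModelContext) -> None:
        # Import inside the method: the user model is cloudpickled, so
        # module-level imports would not survive serialization (see the
        # note in example/model.ipynb).
        import torch

        # Round-robin workers across the visible GPUs (illustrative scheme);
        # fall back to CPU when no GPU is present.
        gpu_count = torch.cuda.device_count()
        if gpu_count > 0:
            self.device = torch.device(f"cuda:{context['worker_index'] % gpu_count}")
        else:
            self.device = torch.device("cpu")
        # Placeholder weights; a real model would load them from an artifact dir.
        self.model = torch.nn.Linear(2, 1).to(self.device)

    def infer(self, x1: float, x2: float) -> Any:
        import torch

        with torch.no_grad():
            out = self.model(torch.tensor([[x1, x2]], device=self.device))
        return out.item()
```

With the container's `ENV NUM_WORKERS=2` (or `NUM_WORKERS` overridden at run time), `build_inference_service` passes `workers=num_workers` to `@bentoml.service`, and each worker then receives a distinct zero-based `worker_index` derived from `bentoml.server_context.worker_index - 1`.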