Feat/log params metrics (#7)
* feat: log params and metrics

kyuwoo-choi authored Nov 29, 2024
1 parent a652dbb commit df0fdfb
Showing 3 changed files with 68 additions and 17 deletions.
41 changes: 24 additions & 17 deletions example/model.ipynb
@@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -49,29 +49,29 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024/11/21 11:55:16 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
"2024/11/28 14:29:24 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
" - nubison-model (current: 0.0.2.dev3+3e1558a.20241118053748, required: nubison-model==0.0.1)\n",
"To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.\n",
"2024/11/21 11:55:16 WARNING mlflow.models.model: Model logged without a signature and input example. Please set `input_example` parameter when logging the model to auto infer the model signature.\n",
"2024/11/28 14:29:24 WARNING mlflow.models.model: Model logged without a signature and input example. Please set `input_example` parameter when logging the model to auto infer the model signature.\n",
"Registered model 'Default' already exists. Creating a new version of this model...\n",
"2024/11/21 11:55:16 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: Default, version 139\n",
"Created version '139' of model 'Default'.\n",
"2024/11/21 11:55:16 INFO mlflow.tracking._tracking_service.client: 🏃 View run hilarious-hound-941 at: http://127.0.0.1:5000/#/experiments/0/runs/7a07fb6fa06549558fbbb35778a3a938.\n",
"2024/11/21 11:55:16 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/0.\n"
"2024/11/28 14:29:24 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: Default, version 141\n",
"Created version '141' of model 'Default'.\n",
"2024/11/28 14:29:24 INFO mlflow.tracking._tracking_service.client: 🏃 View run bittersweet-wren-934 at: http://127.0.0.1:5000/#/experiments/0/runs/f6de73428f38482f9d4ee6ae86494793.\n",
"2024/11/28 14:29:24 INFO mlflow.tracking._tracking_service.client: 🧪 View experiment at: http://127.0.0.1:5000/#/experiments/0.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model registered: runs:/7a07fb6fa06549558fbbb35778a3a938/\n"
"Model registered: runs:/f6de73428f38482f9d4ee6ae86494793/\n"
]
}
],
@@ -81,8 +81,8 @@
"\n",
"# Register the user model\n",
"# The `artifact_dirs` argument specifies the folders containing the files used by the model class.\n",
"model_id = register(UserModel(), artifact_dirs=\"src\")\n",
"print(f\"Model registered: {model_id}\")"
"model_id = register(UserModel(), artifact_dirs=\"src\", params={\"desc\": \"This is a test model\"}, metrics={\"train_score\": 0.9, \"val_score\": 0.8, \"test_score\": 0.7})\n",
"print(f\"Model registered: {model_id}\")\n"
]
},
{
@@ -94,14 +94,14 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024/11/21 11:55:18 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
"2024/11/28 14:29:26 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
" - nubison-model (current: 0.0.2.dev3+3e1558a.20241118053748, required: nubison-model==0.0.1)\n",
"To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.\n"
]
@@ -110,20 +110,20 @@
"name": "stderr",
"output_type": "stream",
"text": [
"2024/11/21 11:55:18 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
"2024/11/28 14:29:26 WARNING mlflow.utils.requirements_utils: Detected one or more mismatches between the model's dependencies and the current Python environment:\n",
" - nubison-model (current: 0.0.2.dev3+3e1558a.20241118053748, required: nubison-model==0.0.1)\n",
"To fix the mismatches, call `mlflow.pyfunc.get_model_dependencies(model_uri)` to fetch the model's environment and install dependencies using the resulting environment file.\n",
"2024-11-21 11:55:18,653 - SimpleLinearModel - INFO - Weights loaded successfully from ./src/weights.txt.\n",
"2024-11-28 14:29:26,359 - SimpleLinearModel - INFO - Weights loaded successfully from ./src/weights.txt.\n",
"INFO:SimpleLinearModel:Weights loaded successfully from ./src/weights.txt.\n",
"2024-11-21 11:55:18,671 - SimpleLinearModel - INFO - Calculating the result of the linear model with x1=3.1, x2=2.0.\n",
"2024-11-28 14:29:26,367 - SimpleLinearModel - INFO - Calculating the result of the linear model with x1=3.1, x2=2.0.\n",
"INFO:SimpleLinearModel:Calculating the result of the linear model with x1=3.1, x2=2.0.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prepared artifact: src -> /tmp/tmpttu4qyi4/artifacts/src\n",
"Prepared artifact: src -> /tmp/tmprsct2rwt/artifacts/src\n",
"The result of the linear model is 4.35.\n"
]
}
@@ -139,6 +139,13 @@
" print(f\"The result of the linear model is {result.json()['y']}.\")\n",
"\n"
]
- }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
],
"metadata": {
8 changes: 8 additions & 0 deletions nubison_model/Model.py
@@ -138,6 +138,8 @@ def register(
model_name: Optional[str] = None,
mlflow_uri: Optional[str] = None,
artifact_dirs: Optional[str] = None,
+ params: Optional[dict[str, Any]] = None,
+ metrics: Optional[dict[str, float]] = None,
):
# Check if the model implements the Model protocol
if not isinstance(model, NubisonModel):
@@ -155,6 +157,12 @@

# Start a new MLflow run
with mlflow.start_run() as run:
+ # Log parameters and metrics
+ if params:
+     mlflow.log_params(params)
+ if metrics:
+     mlflow.log_metrics(metrics)

# Log the model to MLflow
model_info: ModelInfo = mlflow.pyfunc.log_model(
registered_model_name=model_name,
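For reference, a minimal usage sketch of the extended `register` signature, mirroring the notebook example above. This is not prescriptive: `UserModel` is a stand-in for any concrete `NubisonModel` implementation, the `src` artifact directory is taken from the example, and a reachable MLflow tracking server is assumed.

```python
from nubison_model import register  # assumes `register` is exported at package level

# `UserModel` is a hypothetical NubisonModel implementation from the example notebook.
model_id = register(
    UserModel(),
    artifact_dirs="src",  # folders containing files used by the model class
    params={"desc": "This is a test model"},  # logged via mlflow.log_params
    metrics={"train_score": 0.9, "val_score": 0.8, "test_score": 0.7},  # logged via mlflow.log_metrics
)
print(f"Model registered: {model_id}")  # e.g. "runs:/<run_id>/"
```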
36 changes: 36 additions & 0 deletions test/test_model.py
@@ -118,3 +118,39 @@ def test_artifact_dirs_from_env():
with temporary_env({"ARTIFACT_DIRS": "src"}):
assert _make_artifact_dir_dict(None) == {"src": "src"}
assert _make_artifact_dir_dict("src,test") == {"src": "src", "test": "test"}


def test_log_params_and_metrics(mlflow_server):
"""
Test logging parameters and metrics to MLflow.
"""
model_name = "TestLoggedModel"

class DummyModel(NubisonModel):
pass

# Test parameters and metrics
test_params = {"param1": "value1", "param2": "value2"}
test_metrics = {"metric1": 1.0, "metric2": 2.0}

# Register model with params and metrics
model_uri = register(
DummyModel(),
model_name=model_name,
params=test_params,
metrics=test_metrics,
)

# Extract run_id from model_uri (format: "runs:/run_id/path")
run_id = model_uri.split("/")[1]

# Get the run information from MLflow
client = MlflowClient()
run = client.get_run(run_id)

assert set(test_params.items()) <= set(
run.data.params.items()
), "Not all parameters were logged correctly"
assert set(test_metrics.items()) <= set(
run.data.metrics.items()
), "Not all metrics were logged correctly"
