-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_pipeline.py
31 lines (26 loc) · 920 Bytes
/
run_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import click
from pipelines.training_pipeline import customer_satisfaction_training_pipeline
from zenml.integrations.mlflow.mlflow_utils import get_tracking_uri
@click.command()
@click.option(
    "--model_type",
    "-m",
    # NOTE: "xgboost" was previously listed as a choice but is currently disabled.
    type=click.Choice(["lightgbm", "randomforest"]),
    default="randomforest",
    help="Here you can choose what type of model should be trained.",
)
def main(model_type: str) -> None:
    """Run the customer-satisfaction training pipeline, then print MLflow UI hints.

    The pipeline is configured with options from ``config.yaml`` and is called
    with the selected model type as its only argument. After the run, the
    MLflow tracking URI is looked up so the user can point ``mlflow ui`` at it.

    Args:
        model_type: Model family to train; one of the ``--model_type`` choices
            ("lightgbm" or "randomforest").
    """
    customer_satisfaction_training_pipeline.with_options(
        config_path="config.yaml"
    )(model_type)

    # NOTE(review): the experiment name below is hard-coded; confirm it matches
    # the experiment the MLflow tracker actually logs this pipeline under.
    print(
        "Now run \n "
        f" mlflow ui --backend-store-uri '{get_tracking_uri()}'\n"
        "To inspect your experiment runs within the mlflow UI.\n"
        # Trailing space is required: adjacent literals are joined verbatim.
        "You can find your runs tracked within the `mlflow_example_pipeline` "
        "experiment. Here you'll also be able to compare the two runs."
    )
if __name__ == "__main__":
    # Script entry point: invoking the click command parses the CLI arguments
    # and dispatches to main().
    main()