Databricks-ML-professional-S02b-Model-Management
This Notebook adds information related to the following requirements:
Model Management:
- Describe the basic purpose and user interactions with Model Registry
- Programmatically register a new model or new model version
- Add metadata to a registered model and a registered model version
- Identify, compare, and contrast the available model stages
- Transition, archive, and delete model versions
MLflow Model Registry is a collaborative hub where teams can share ML models, work together from experimentation to online testing and production, integrate with approval and governance workflows, and monitor ML deployments and their performance.
- Is a centralized model store, set of APIs, and UI, to collaboratively manage the full lifecycle of an MLflow Model.
- Provides model lineage (which MLflow Experiment and Run produced the model), model versioning, stage transitions (e.g. from staging to production), annotations (e.g. with comments, tags), and deployment management (e.g. which production jobs have requested a specific model version)
- Features of Model Registry:
  - Central Repository: Register MLflow models with the MLflow Model Registry. A registered model has a unique name, version, stage, and other metadata.
  - Model Versioning: Automatically keep track of versions for registered models when updated.
  - Model Stage: Assign preset or custom stages to each model version, like “Staging” and “Production”, to represent the lifecycle of a model.
  - Model Stage Transitions: Record new registration events or changes as activities that automatically log users, changes, and additional metadata such as comments.
  - CI/CD Workflow Integration: Record stage transitions, request, review and approve changes as part of CI/CD pipelines for better control and governance.
- Can be managed using the UI or in pure Python.

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
#
import seaborn as sns
#
import mlflow
#
import logging
import re
mlflow.sklearn.autolog(disable=True)
logging.getLogger("mlflow").setLevel(logging.FATAL)
diamonds_df = sns.load_dataset("diamonds").drop(["cut", "color", "clarity"], axis=1)
#
X_train, X_test, y_train, y_test = train_test_split(diamonds_df.drop(["price"], axis=1), diamonds_df["price"], test_size=0.33)
#
model = LinearRegression().fit(X_train, y_train)
model_path = "sns_diamonds"
#
with mlflow.start_run(run_name="register_diamonds") as run:
    mlflow.sklearn.log_model(sk_model=model,
                             artifact_path=model_path)
Programmatically register the latest logged model:
Note that running the command below multiple times registers a new version each time, with the version number incremented by one from the last registered version.
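# get the run ID of the most recently finished run
# (.iloc[0] selects by position; after sorting, the index label 0 is not guaranteed to be the first row)
latest_run_id = mlflow.search_runs().sort_values(by="end_time", ascending=False).iloc[0]["run_id"]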
#
mlflow.register_model(f"runs:/{latest_run_id}/{model_path}", name="lr_sns_diamonds");
Registered model 'lr_sns_diamonds' already exists. Creating a new version of this model... Created version '11' of model 'lr_sns_diamonds'.
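To check which version number a registration produced, the versions of a registered model can be queried afterwards. The following is a minimal sketch: the helper name `latest_version_number` is hypothetical, and it only assumes a client that exposes `search_model_versions` the way `mlflow.MlflowClient` does.

```python
def latest_version_number(client, name):
    """Return the highest registered version of `name` as an int, or None.

    `client` is expected to behave like mlflow.MlflowClient; the helper
    name is illustrative and not part of MLflow.
    """
    versions = client.search_model_versions(filter_string=f"name='{name}'")
    # ModelVersion.version is reported as a string such as '11',
    # so convert to int before comparing.
    numbers = [int(mv.version) for mv in versions]
    return max(numbers) if numbers else None
```

Calling `latest_version_number(mlflow.MlflowClient(), "lr_sns_diamonds")` before and after `mlflow.register_model` should show the number going up by one.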
Alternatively, a newly logged model can be registered automatically by passing the registered_model_name parameter to mlflow.sklearn.log_model:
diamonds_df = sns.load_dataset("diamonds").drop(["cut", "color", "clarity"], axis=1)
#
X_train, X_test, y_train, y_test = train_test_split(diamonds_df.drop(["price"], axis=1), diamonds_df["price"], test_size=0.33)
#
model = LinearRegression().fit(X_train, y_train)
model_path = "sns_diamonds"
#
with mlflow.start_run(run_name="register_diamonds") as run:
    mlflow.sklearn.log_model(sk_model=model,
                             artifact_path=model_path,
                             registered_model_name="lr_sns_diamonds")
Registered model 'lr_sns_diamonds' already exists. Creating a new version of this model... Created version '12' of model 'lr_sns_diamonds'.
Alternatively, a new registered model can be created from scratch, and then filled with a model from an existing run:
from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
# Register model name in the model registry
client = mlflow.MlflowClient()
client.create_registered_model("sns_diamonds_create")
# Create a new version of the linear regression model under the registered model name
desc = "A new version of sns diamonds dataset linear regressions model"
runs_uri = f"runs:/{latest_run_id}/{model_path}"
model_src = RunsArtifactRepository.get_underlying_uri(runs_uri)
mv = client.create_model_version("sns_diamonds_create", model_src, latest_run_id, description=desc)
client = mlflow.MlflowClient()
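`create_registered_model` raises an error when the name is already taken, so re-running the cell above fails on its second run. One way around this is to check for the name first. This is a sketch under assumptions: `ensure_registered_model` is a hypothetical helper, and it relies on `search_registered_models` accepting a `name='...'` filter string.

```python
def ensure_registered_model(client, name):
    """Create the registered model `name` only if it does not exist yet.

    Returns True when a new registered model was created, False when one
    with that name was already present. `client` is expected to behave
    like mlflow.MlflowClient.
    """
    existing = client.search_registered_models(filter_string=f"name='{name}'")
    if len(existing) > 0:
        return False
    client.create_registered_model(name)
    return True
```

With this helper, `ensure_registered_model(client, "sns_diamonds_create")` can be run repeatedly before `create_model_version` without failing on the second run.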
Registered model metadata can be listed:
for val in client.get_registered_model("sns_diamonds_create"):
    print(val)
# Set registered model tag (the model is a linear regression, so the task is regression)
client.set_registered_model_tag("sns_diamonds_create", "task", "regression")
for val in client.get_registered_model("sns_diamonds_create"):
    print(val)
# Set model version tag
client.set_model_version_tag("sns_diamonds_create", "2", "validation_status", "approved")
for val in client.get_registered_model("sns_diamonds_create"):
    print(val)
# Add or update description
client.update_model_version(
    name="sns_diamonds_create",
    version=1,
    description="This is the first version of sns_diamonds_create model",
)
for version in client.search_model_versions(filter_string="name='sns_diamonds_create'"):
    print(f"Description of version {version.version} of the model: {version.description}")
Description of version 1 of the model: This is the first version of sns_diamonds_create model
Description of version 2 of the model: A new version of sns diamonds dataset linear regressions model
See the MLflow documentation for more information on how to update a registered model.
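Descriptions are not limited to model versions. The registered model itself can carry one through `update_registered_model`, which appears in the client method list at the end of this notebook. The sketch below combines it with `set_registered_model_tag`; the helper name `annotate_model` and its arguments are illustrative only.

```python
def annotate_model(client, name, description, tags):
    """Set a description and a dict of tags on the registered model `name`.

    `client` is expected to behave like mlflow.MlflowClient; `tags` maps
    tag keys to tag values.
    """
    client.update_registered_model(name=name, description=description)
    for key, value in tags.items():
        client.set_registered_model_tag(name, key, value)
```

For example, `annotate_model(client, "sns_diamonds_create", "Linear regression on the seaborn diamonds dataset", {"task": "regression"})` would set both pieces of metadata in one call.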
- The MLflow Model Registry defines several model stages:
  - None: the model is in development
  - Staging: the model is to be tested
  - Production: the model is tested, validated, and in production
  - Archived: the model version is no longer in use and is kept for reference
By default, a newly registered model version is assigned the stage None.
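To see how the versions of one model are spread across these stages, they can be grouped by their `current_stage` attribute. This is a minimal sketch: `versions_by_stage` is a hypothetical helper, and it assumes the stage is reported as a string, including the literal string 'None' for versions that have no stage yet.

```python
from collections import defaultdict


def versions_by_stage(client, name):
    """Map each stage name to the sorted version numbers currently in it.

    `client` is expected to behave like mlflow.MlflowClient.
    """
    grouped = defaultdict(list)
    for mv in client.search_model_versions(filter_string=f"name='{name}'"):
        # current_stage is a string: 'None', 'Staging', 'Production' or 'Archived'.
        grouped[mv.current_stage].append(int(mv.version))
    return {stage: sorted(numbers) for stage, numbers in grouped.items()}
```

Stages that hold no version are simply absent from the returned dictionary.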
List registered models and their latest version:
client = mlflow.MlflowClient()
#
registered_models = []
for model in client.search_registered_models():
    # latest_versions holds the latest version for each stage, so its first
    # entry is not necessarily the highest version number of the model
    if len(model.latest_versions) > 0:
        registered_models.append((model.latest_versions[0].name,
                                  model.latest_versions[0].run_id,
                                  model.latest_versions[0].version,
                                  model.latest_versions[0].current_stage,
                                  model.latest_versions[0].status,
                                  model.latest_versions[0].tags,
                                  model.latest_versions[0].source))
#
display(spark.createDataFrame(registered_models, ['name', 'run_id', 'latest_version', 'current_stage', 'status', 'tags', 'source']))
name | run_id | latest_version | current_stage | status | tags | source |
---|---|---|---|---|---|---|
lr_sns_diamonds | 622df96ffd7f48589094b425e71743e1 | 4 | Archived | READY | Map() | dbfs:/databricks/mlflow-tracking/3581746151855359/622df96ffd7f48589094b425e71743e1/artifacts/sns_diamonds |
lr_sns_diamonds_v1 | 38ddf31bed524b7baf66eb93a8962754 | 1 | None | READY | Map() | dbfs:/databricks/mlflow-tracking/3581746151855359/38ddf31bed524b7baf66eb93a8962754/artifacts/sns_diamonds |
ml_flow_AddN_test | 0620a4cfcdcd407d92bc7680cc6312f7 | 1 | None | READY | Map() | dbfs:/databricks/mlflow-tracking/3515657616447348/0620a4cfcdcd407d92bc7680cc6312f7/artifacts/add5_model |
mllib_model | 6ccc766ac64c4250b3302487dfabf2da | 2 | None | READY | Map() | dbfs:/databricks/mlflow-tracking/3601260417884136/6ccc766ac64c4250b3302487dfabf2da/artifacts/model |
register-LR-All-Features-test | | 2 | None | READY | Map() | dbfs:/databricks/mlflow-tracking/1878608598223959/512a0de8368b446bb663288c06ed0858/artifacts/model |
registered-5-model | 17c2fb1fbdcf4bdb97cd9681cf9cca8d | 1 | None | READY | Map() | dbfs:/databricks/mlflow-tracking/3079599880510783/17c2fb1fbdcf4bdb97cd9681cf9cca8d/artifacts/pyfunc-model |
scikit-learn_model | 9403ddccf7b64568a462bbba2cdccaee | 5 | None | READY | Map() | dbfs:/databricks/mlflow-tracking/3601260417884136/9403ddccf7b64568a462bbba2cdccaee/artifacts/model |
sns_diamonds_create | 5f6052cad9f34deba6f934f69707a093 | 2 | None | READY | Map(validation_status -> approved) | dbfs:/databricks/mlflow-tracking/3581746151855359/5f6052cad9f34deba6f934f69707a093/artifacts/sns_diamonds |
trained_with_4_features | ce5848dbbaec406daa95f057f14722bb | 6 | None | READY | Map() | dbfs:/databricks/mlflow-tracking/3601260417884136/ce5848dbbaec406daa95f057f14722bb/artifacts/trained_with_4_features |
trained_with_all_features | ce5848dbbaec406daa95f057f14722bb | 5 | None | READY | Map() | dbfs:/databricks/mlflow-tracking/3601260417884136/ce5848dbbaec406daa95f057f14722bb/artifacts/trained_with_all_features |
webhook_diamonds | ed6f91126eb149e7bf39c024da865a00 | 1 | Staging | READY | Map() | dbfs:/databricks/mlflow-tracking/1352035400533066/ed6f91126eb149e7bf39c024da865a00/artifacts/webhook-model |
Info about one specific registered model and its latest version:
client.get_registered_model('lr_sns_diamonds')
Info about one specific model and a given version:
client.get_model_version('lr_sns_diamonds', 2)
Transition a specific version of a registered model to a given stage. Valid values for stage are: Production, Staging, Archived, None
client.transition_model_version_stage('lr_sns_diamonds', 3, 'Production')
Out[41]: <ModelVersion: creation_timestamp=1699465110132, current_stage='Production', description='', last_updated_timestamp=1699605828059, name='lr_sns_diamonds', run_id='e369770189f5436fb28cd66f2a898c1d', run_link='', source='dbfs:/databricks/mlflow-tracking/3581746151855359/e369770189f5436fb28cd66f2a898c1d/artifacts/sns_diamonds', status='READY', status_message='', tags={}, user_id='2329071338839022', version='3'>
Archiving a specific version of a registered model uses the same command:
client.transition_model_version_stage('lr_sns_diamonds', 4, 'Archived')
Out[42]: <ModelVersion: creation_timestamp=1699465347090, current_stage='Archived', description='', last_updated_timestamp=1699605913340, name='lr_sns_diamonds', run_id='622df96ffd7f48589094b425e71743e1', run_link='', source='dbfs:/databricks/mlflow-tracking/3581746151855359/622df96ffd7f48589094b425e71743e1/artifacts/sns_diamonds', status='READY', status_message='', tags={}, user_id='2329071338839022', version='4'>
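When a new version is promoted, the version it replaces usually has to be archived as well. Rather than issuing two transitions, `transition_model_version_stage` accepts an `archive_existing_versions` flag that archives the versions already in the target stage. A minimal sketch, where `promote_to_production` is a hypothetical wrapper:

```python
def promote_to_production(client, name, version):
    """Move `version` of `name` to Production and archive the versions
    that were in Production before.

    `client` is expected to behave like mlflow.MlflowClient; the
    resulting ModelVersion is returned unchanged.
    """
    return client.transition_model_version_stage(
        name=name,
        version=version,
        stage="Production",
        archive_existing_versions=True,
    )
```

For instance, `promote_to_production(client, "lr_sns_diamonds", 5)` would be expected to leave version 5 in Production and move version 3 to Archived.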
Delete a registered model version:
client.delete_model_version('lr_sns_diamonds', 6)
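Databricks documents that only versions in the None or Archived stage can be deleted. Assuming that restriction applies to the workspace in use, a deletion can be made safe by archiving the version first when it is still in Staging or Production. The helper name `archive_then_delete` and the `ACTIVE_STAGES` constant are illustrative only.

```python
# Stages from which a version is assumed not to be directly deletable.
ACTIVE_STAGES = ("Staging", "Production")


def archive_then_delete(client, name, version):
    """Delete `version` of `name`, archiving it first if it is still active.

    `client` is expected to behave like mlflow.MlflowClient.
    """
    mv = client.get_model_version(name, version)
    if mv.current_stage in ACTIVE_STAGES:
        client.transition_model_version_stage(name, version, "Archived")
    client.delete_model_version(name, version)
```

Deletion is permanent, so a guard like this is probably best reserved for cleanup scripts where the versions to remove are already known.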
Get a list of available properties and methods:
display(spark.createDataFrame([{'props_and_methods': method} for method in dir(client) if method[0]!='_']))
props_and_methods |
---|
create_experiment |
create_model_version |
create_registered_model |
create_run |
delete_experiment |
delete_model_version |
delete_model_version_tag |
delete_registered_model |
delete_registered_model_tag |
delete_run |
delete_tag |
download_artifacts |
get_experiment |
get_experiment_by_name |
get_latest_versions |
get_metric_history |
get_model_version |
get_model_version_download_uri |
get_model_version_stages |
get_registered_model |
get_run |
list_artifacts |
list_experiments |
list_registered_models |
list_run_infos |
log_artifact |
log_artifacts |
log_batch |
log_dict |
log_figure |
log_image |
log_metric |
log_param |
log_text |
rename_experiment |
rename_registered_model |
restore_experiment |
restore_run |
search_experiments |
search_model_versions |
search_registered_models |
search_runs |
set_experiment_tag |
set_model_version_tag |
set_registered_model_tag |
set_tag |
set_terminated |
tracking_uri |
transition_model_version_stage |
update_model_version |
update_registered_model |