MLflow
zenml.integrations.mlflow
special
Initialization for the ZenML MLflow integration.
The MLflow integration currently enables you to use MLflow tracking as a convenient way to visualize your experiment runs within the MLflow UI.
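To make the integration concrete, here is a minimal usage sketch; the tracker name `mlflow_tracker` is an assumption and stands for whatever MLflow experiment tracker is registered in your active stack.

```python
# Minimal sketch: logging to MLflow from a ZenML step. Assumes an MLflow
# experiment tracker registered under the (hypothetical) name "mlflow_tracker".
import mlflow

from zenml import pipeline, step


@step(experiment_tracker="mlflow_tracker")
def train() -> float:
    accuracy = 0.95  # placeholder metric
    # Logged values land in the MLflow run that ZenML opens for this step
    # (see `prepare_step_run` below).
    mlflow.log_param("lr", 0.001)
    mlflow.log_metric("accuracy", accuracy)
    return accuracy


@pipeline
def training_pipeline() -> None:
    train()
```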
MlflowIntegration (Integration)
Definition of MLflow integration for ZenML.
Source code in zenml/integrations/mlflow/__init__.py
class MlflowIntegration(Integration):
"""Definition of MLflow integration for ZenML."""
NAME = MLFLOW
REQUIREMENTS_IGNORED_ON_UNINSTALL = [
"python-rapidjson",
"pydantic",
"numpy",
"pandas",
]
@classmethod
def get_requirements(cls, target_os: Optional[str] = None) -> List[str]:
"""Method to get the requirements for the integration.
Args:
target_os: The target operating system to get the requirements for.
Returns:
A list of requirements.
"""
from zenml.integrations.numpy import NumpyIntegration
from zenml.integrations.pandas import PandasIntegration
reqs = [
"mlflow>=2.1.1,<3",
# TODO: remove this requirement once rapidjson is fixed
"python-rapidjson<1.15",
# When you do:
# pip install zenml
# You get all our required dependencies. However, if you follow it
# with:
# zenml integration install mlflow
# This downgrades pydantic to v1 even though mlflow does not have
# any issues with v2. This is why we have to pin it here so a
# downgrade will not happen.
"pydantic>=2.8.0,<2.9.0",
]
if sys.version_info.minor >= 12:
logger.debug(
"The MLflow integration on Python 3.12 and above is not yet "
"fully supported: The extra dependencies 'mlserver' and "
"'mlserver-mlflow' will be skipped."
)
else:
reqs.extend([
"mlserver>=1.3.3",
"mlserver-mlflow>=1.3.3",
])
reqs.extend(NumpyIntegration.get_requirements(target_os=target_os))
reqs.extend(PandasIntegration.get_requirements(target_os=target_os))
return reqs
@classmethod
def activate(cls) -> None:
"""Activate the MLflow integration."""
from zenml.integrations.mlflow import services # noqa
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the MLflow integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.mlflow.flavors import (
MLFlowExperimentTrackerFlavor,
MLFlowModelDeployerFlavor,
MLFlowModelRegistryFlavor,
)
return [
MLFlowModelDeployerFlavor,
MLFlowExperimentTrackerFlavor,
MLFlowModelRegistryFlavor,
]
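For illustration, the class methods above can be called directly to inspect what the integration would install and register; a small sketch:

```python
# Sketch: inspecting the integration's pinned requirements and stack flavors.
from zenml.integrations.mlflow import MlflowIntegration

# Pinned pip requirements, e.g. ["mlflow>=2.1.1,<3", ...]
print(MlflowIntegration.get_requirements())

# The three stack component flavors contributed by this integration.
for flavor_class in MlflowIntegration.flavors():
    print(flavor_class().name)
```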
activate()
classmethod
Activate the MLflow integration.
Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def activate(cls) -> None:
"""Activate the MLflow integration."""
from zenml.integrations.mlflow import services # noqa
flavors()
classmethod
Declare the stack component flavors for the MLflow integration.
Returns:

Type | Description
---|---
`List[Type[zenml.stack.flavor.Flavor]]` | List of stack component flavors for this integration.
Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
"""Declare the stack component flavors for the MLflow integration.
Returns:
List of stack component flavors for this integration.
"""
from zenml.integrations.mlflow.flavors import (
MLFlowExperimentTrackerFlavor,
MLFlowModelDeployerFlavor,
MLFlowModelRegistryFlavor,
)
return [
MLFlowModelDeployerFlavor,
MLFlowExperimentTrackerFlavor,
MLFlowModelRegistryFlavor,
]
get_requirements(target_os=None)
classmethod
Method to get the requirements for the integration.
Parameters:

Name | Type | Description | Default
---|---|---|---
`target_os` | `Optional[str]` | The target operating system to get the requirements for. | `None`

Returns:

Type | Description
---|---
`List[str]` | A list of requirements.
Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def get_requirements(cls, target_os: Optional[str] = None) -> List[str]:
"""Method to get the requirements for the integration.
Args:
target_os: The target operating system to get the requirements for.
Returns:
A list of requirements.
"""
from zenml.integrations.numpy import NumpyIntegration
from zenml.integrations.pandas import PandasIntegration
reqs = [
"mlflow>=2.1.1,<3",
# TODO: remove this requirement once rapidjson is fixed
"python-rapidjson<1.15",
# When you do:
# pip install zenml
# You get all our required dependencies. However, if you follow it
# with:
# zenml integration install mlflow
# This downgrades pydantic to v1 even though mlflow does not have
# any issues with v2. This is why we have to pin it here so a
# downgrade will not happen.
"pydantic>=2.8.0,<2.9.0",
]
if sys.version_info.minor >= 12:
logger.debug(
"The MLflow integration on Python 3.12 and above is not yet "
"fully supported: The extra dependencies 'mlserver' and "
"'mlserver-mlflow' will be skipped."
)
else:
reqs.extend([
"mlserver>=1.3.3",
"mlserver-mlflow>=1.3.3",
])
reqs.extend(NumpyIntegration.get_requirements(target_os=target_os))
reqs.extend(PandasIntegration.get_requirements(target_os=target_os))
return reqs
experiment_trackers
special
Initialization of the MLflow experiment tracker.
mlflow_experiment_tracker
Implementation of the MLflow experiment tracker for ZenML.
MLFlowExperimentTracker (BaseExperimentTracker)
Track experiments using MLflow.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
class MLFlowExperimentTracker(BaseExperimentTracker):
"""Track experiments using MLflow."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the experiment tracker and validate the tracking uri.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
super().__init__(*args, **kwargs)
self._ensure_valid_tracking_uri()
def _ensure_valid_tracking_uri(self) -> None:
"""Ensures that the tracking uri is a valid mlflow tracking uri.
Raises:
ValueError: If the tracking uri is not valid.
"""
tracking_uri = self.config.tracking_uri
if tracking_uri:
valid_schemes = DATABASE_ENGINES + ["http", "https", "file"]
if not any(
tracking_uri.startswith(scheme) for scheme in valid_schemes
) and not is_databricks_tracking_uri(tracking_uri):
raise ValueError(
f"MLflow tracking uri does not start with one of the valid "
f"schemes {valid_schemes} or its value is not set to "
f"'databricks'. See "
f"https://www.mlflow.org/docs/latest/tracking.html#where-runs-are-recorded "
f"for more information."
)
@property
def config(self) -> MLFlowExperimentTrackerConfig:
"""Returns the `MLFlowExperimentTrackerConfig` config.
Returns:
The configuration.
"""
return cast(MLFlowExperimentTrackerConfig, self._config)
@property
def local_path(self) -> Optional[str]:
"""Path to the local directory where the MLflow artifacts are stored.
Returns:
None if configured with a remote tracking URI, otherwise the
path to the local MLflow artifact store directory.
"""
tracking_uri = self.get_tracking_uri()
if is_remote_mlflow_tracking_uri(tracking_uri):
return None
else:
assert tracking_uri.startswith("file:")
return tracking_uri[5:]
@property
def validator(self) -> Optional["StackValidator"]:
"""Checks the stack has a `LocalArtifactStore` if no tracking uri was specified.
Returns:
An optional `StackValidator`.
"""
if self.config.tracking_uri:
# user specified a tracking uri, do nothing
return None
else:
# try to fall back to a tracking uri inside the zenml artifact
# store. this only works in case of a local artifact store, so we
# make sure to prevent stack with other artifact stores for now
return StackValidator(
custom_validation_function=lambda stack: (
isinstance(stack.artifact_store, LocalArtifactStore),
"MLflow experiment tracker without a specified tracking "
"uri only works with a local artifact store.",
)
)
@property
def settings_class(self) -> Optional[Type["BaseSettings"]]:
"""Settings class for the Mlflow experiment tracker.
Returns:
The settings class.
"""
return MLFlowExperimentTrackerSettings
@staticmethod
def _local_mlflow_backend() -> str:
"""Gets the local MLflow backend inside the ZenML artifact repository directory.
Returns:
The MLflow tracking URI for the local MLflow backend.
"""
client = Client()
artifact_store = client.active_stack.artifact_store
local_mlflow_tracking_uri = os.path.join(artifact_store.path, "mlruns")
if not os.path.exists(local_mlflow_tracking_uri):
os.makedirs(local_mlflow_tracking_uri)
return "file:" + local_mlflow_tracking_uri
def get_tracking_uri(self, as_plain_text: bool = True) -> str:
"""Returns the configured tracking URI or a local fallback.
Args:
as_plain_text: Whether to return the tracking URI as plain text.
Returns:
The tracking URI.
"""
if as_plain_text:
tracking_uri = self.config.tracking_uri
else:
tracking_uri = self.config.model_dump()["tracking_uri"]
return tracking_uri or self._local_mlflow_backend()
def prepare_step_run(self, info: "StepRunInfo") -> None:
"""Sets the MLflow tracking uri and credentials.
Args:
info: Info about the step that will be executed.
"""
self.configure_mlflow()
settings = cast(
MLFlowExperimentTrackerSettings,
self.get_settings(info),
)
experiment_name = settings.experiment_name or info.pipeline.name
experiment = self._set_active_experiment(experiment_name)
run_id = self.get_run_id(
experiment_name=experiment_name, run_name=info.run_name
)
tags = settings.tags.copy()
tags.update(self._get_internal_tags())
mlflow.start_run(
run_id=run_id,
run_name=info.run_name,
experiment_id=experiment.experiment_id,
tags=tags,
)
if settings.nested:
mlflow.start_run(
run_name=info.pipeline_step_name, nested=True, tags=tags
)
def get_step_run_metadata(
self, info: "StepRunInfo"
) -> Dict[str, "MetadataType"]:
"""Get component- and step-specific metadata after a step ran.
Args:
info: Info about the step that was executed.
Returns:
A dictionary of metadata.
"""
metadata: Dict[str, Any] = {
METADATA_EXPERIMENT_TRACKER_URL: Uri(
self.get_tracking_uri(as_plain_text=False)
),
}
if run := mlflow.active_run():
metadata["mlflow_run_id"] = run.info.run_id
metadata["mlflow_experiment_id"] = run.info.experiment_id
return metadata
def disable_autologging(self) -> None:
"""Disables MLflow autologging for all supported frameworks."""
frameworks = [
"tensorflow",
"gluon",
"xgboost",
"lightgbm",
"statsmodels",
"spark",
"sklearn",
"fastai",
"pytorch",
]
failed_frameworks = []
for framework in frameworks:
try:
# Correctly prefix the module name with 'mlflow.'
module_name = f"mlflow.{framework}"
# Dynamically import the module corresponding to the framework
module = importlib.import_module(module_name)
# Call the autolog function with disable=True
module.autolog(disable=True)
except ImportError as e:
# only log on mlflow relevant errors
if "mlflow" in e.msg.lower():
failed_frameworks.append(framework)
except Exception:
failed_frameworks.append(framework)
if len(failed_frameworks) > 0:
logger.warning(
f"Failed to disable MLflow autologging for the following frameworks: "
f"{failed_frameworks}."
)
def cleanup_step_run(
self,
info: "StepRunInfo",
step_failed: bool,
) -> None:
"""Stops active MLflow runs and resets the MLflow tracking uri.
Args:
info: Info about the step that was executed.
step_failed: Whether the step failed or not.
"""
status = "FAILED" if step_failed else "FINISHED"
self.disable_autologging()
mlflow_utils.stop_zenml_mlflow_runs(status)
mlflow.set_tracking_uri("")
def configure_mlflow(self) -> None:
"""Configures the MLflow tracking URI and any additional credentials."""
tracking_uri = self.get_tracking_uri()
mlflow.set_tracking_uri(tracking_uri)
if is_databricks_tracking_uri(tracking_uri):
if self.config.databricks_host:
os.environ[DATABRICKS_HOST] = self.config.databricks_host
if self.config.tracking_username:
os.environ[DATABRICKS_USERNAME] = self.config.tracking_username
if self.config.tracking_password:
os.environ[DATABRICKS_PASSWORD] = self.config.tracking_password
if self.config.tracking_token:
os.environ[DATABRICKS_TOKEN] = self.config.tracking_token
if self.config.enable_unity_catalog:
mlflow.set_registry_uri(DATABRICKS_UNITY_CATALOG)
else:
os.environ[MLFLOW_TRACKING_URI] = tracking_uri
if self.config.tracking_username:
os.environ[MLFLOW_TRACKING_USERNAME] = (
self.config.tracking_username
)
if self.config.tracking_password:
os.environ[MLFLOW_TRACKING_PASSWORD] = (
self.config.tracking_password
)
if self.config.tracking_token:
os.environ[MLFLOW_TRACKING_TOKEN] = self.config.tracking_token
os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
"true" if self.config.tracking_insecure_tls else "false"
)
def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]:
"""Gets the if of a run with the given name and experiment.
Args:
experiment_name: Name of the experiment in which to search for the
run.
run_name: Name of the run to search.
Returns:
The id of the run if it exists.
"""
self.configure_mlflow()
experiment_name = self._adjust_experiment_name(experiment_name)
runs = mlflow.search_runs(
experiment_names=[experiment_name],
filter_string=f'tags.mlflow.runName = "{run_name}"',
run_view_type=3,
output_format="list",
)
if not runs:
return None
run: Run = runs[0]
if mlflow_utils.is_zenml_run(run):
return cast(str, run.info.run_id)
else:
return None
def _set_active_experiment(self, experiment_name: str) -> Experiment:
"""Sets the active MLflow experiment.
If no experiment with this name exists, it is created and then
activated.
Args:
experiment_name: Name of the experiment to activate.
Raises:
RuntimeError: If the experiment creation or activation failed.
Returns:
The experiment.
"""
experiment_name = self._adjust_experiment_name(experiment_name)
mlflow.set_experiment(experiment_name=experiment_name)
experiment = mlflow.get_experiment_by_name(experiment_name)
if not experiment:
raise RuntimeError("Failed to set active mlflow experiment.")
return experiment
def _adjust_experiment_name(self, experiment_name: str) -> str:
"""Prepends a slash to the experiment name if using Databricks.
Databricks requires the experiment name to be an absolute path within
the Databricks workspace.
Args:
experiment_name: The experiment name.
Returns:
The potentially adjusted experiment name.
"""
tracking_uri = self.get_tracking_uri()
if (
tracking_uri
and is_databricks_tracking_uri(tracking_uri)
and not experiment_name.startswith("/")
):
return f"/{experiment_name}"
else:
return experiment_name
@staticmethod
def _get_internal_tags() -> Dict[str, Any]:
"""Gets ZenML internal tags for MLflow runs.
Returns:
Internal tags.
"""
return {mlflow_utils.ZENML_TAG_KEY: zenml.__version__}
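To see the local fallback described by `get_tracking_uri` and `local_path` in action, a sketch (assuming the active stack contains this experiment tracker and a local artifact store):

```python
# Sketch: with no tracking_uri configured, the tracker falls back to a
# local `mlruns` directory inside the local artifact store.
from zenml.client import Client

tracker = Client().active_stack.experiment_tracker
uri = tracker.get_tracking_uri()
print(uri)                 # e.g. "file:/path/to/artifact_store/mlruns"
print(tracker.local_path)  # same path without the "file:" prefix, or None if remote
```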
config: MLFlowExperimentTrackerConfig
property
readonly
Returns the `MLFlowExperimentTrackerConfig` config.
Returns:

Type | Description
---|---
`MLFlowExperimentTrackerConfig` | The configuration.

local_path: Optional[str]
property
readonly
Path to the local directory where the MLflow artifacts are stored.
Returns:

Type | Description
---|---
`Optional[str]` | None if configured with a remote tracking URI, otherwise the path to the local MLflow artifact store directory.

settings_class: Optional[Type[BaseSettings]]
property
readonly
Settings class for the MLflow experiment tracker.
Returns:

Type | Description
---|---
`Optional[Type[BaseSettings]]` | The settings class.

validator: Optional[StackValidator]
property
readonly
Checks the stack has a `LocalArtifactStore` if no tracking uri was specified.
Returns:

Type | Description
---|---
`Optional[StackValidator]` | An optional `StackValidator`.
__init__(self, *args, **kwargs)
special
Initialize the experiment tracker and validate the tracking uri.
Parameters:

Name | Type | Description | Default
---|---|---|---
`*args` | `Any` | Variable length argument list. | `()`
`**kwargs` | `Any` | Arbitrary keyword arguments. | `{}`
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize the experiment tracker and validate the tracking uri.
Args:
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
"""
super().__init__(*args, **kwargs)
self._ensure_valid_tracking_uri()
cleanup_step_run(self, info, step_failed)
Stops active MLflow runs and resets the MLflow tracking uri.
Parameters:

Name | Type | Description | Default
---|---|---|---
`info` | `StepRunInfo` | Info about the step that was executed. | required
`step_failed` | `bool` | Whether the step failed or not. | required
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def cleanup_step_run(
self,
info: "StepRunInfo",
step_failed: bool,
) -> None:
"""Stops active MLflow runs and resets the MLflow tracking uri.
Args:
info: Info about the step that was executed.
step_failed: Whether the step failed or not.
"""
status = "FAILED" if step_failed else "FINISHED"
self.disable_autologging()
mlflow_utils.stop_zenml_mlflow_runs(status)
mlflow.set_tracking_uri("")
configure_mlflow(self)
Configures the MLflow tracking URI and any additional credentials.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def configure_mlflow(self) -> None:
"""Configures the MLflow tracking URI and any additional credentials."""
tracking_uri = self.get_tracking_uri()
mlflow.set_tracking_uri(tracking_uri)
if is_databricks_tracking_uri(tracking_uri):
if self.config.databricks_host:
os.environ[DATABRICKS_HOST] = self.config.databricks_host
if self.config.tracking_username:
os.environ[DATABRICKS_USERNAME] = self.config.tracking_username
if self.config.tracking_password:
os.environ[DATABRICKS_PASSWORD] = self.config.tracking_password
if self.config.tracking_token:
os.environ[DATABRICKS_TOKEN] = self.config.tracking_token
if self.config.enable_unity_catalog:
mlflow.set_registry_uri(DATABRICKS_UNITY_CATALOG)
else:
os.environ[MLFLOW_TRACKING_URI] = tracking_uri
if self.config.tracking_username:
os.environ[MLFLOW_TRACKING_USERNAME] = (
self.config.tracking_username
)
if self.config.tracking_password:
os.environ[MLFLOW_TRACKING_PASSWORD] = (
self.config.tracking_password
)
if self.config.tracking_token:
os.environ[MLFLOW_TRACKING_TOKEN] = self.config.tracking_token
os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
"true" if self.config.tracking_insecure_tls else "false"
)
disable_autologging(self)
Disables MLflow autologging for all supported frameworks.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def disable_autologging(self) -> None:
"""Disables MLflow autologging for all supported frameworks."""
frameworks = [
"tensorflow",
"gluon",
"xgboost",
"lightgbm",
"statsmodels",
"spark",
"sklearn",
"fastai",
"pytorch",
]
failed_frameworks = []
for framework in frameworks:
try:
# Correctly prefix the module name with 'mlflow.'
module_name = f"mlflow.{framework}"
# Dynamically import the module corresponding to the framework
module = importlib.import_module(module_name)
# Call the autolog function with disable=True
module.autolog(disable=True)
except ImportError as e:
# only log on mlflow relevant errors
if "mlflow" in e.msg.lower():
failed_frameworks.append(framework)
except Exception:
failed_frameworks.append(framework)
if len(failed_frameworks) > 0:
logger.warning(
f"Failed to disable MLflow autologging for the following frameworks: "
f"{failed_frameworks}."
)
get_run_id(self, experiment_name, run_name)
Gets the ID of a run with the given name and experiment.
Parameters:

Name | Type | Description | Default
---|---|---|---
`experiment_name` | `str` | Name of the experiment in which to search for the run. | required
`run_name` | `str` | Name of the run to search. | required

Returns:

Type | Description
---|---
`Optional[str]` | The ID of the run if it exists.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]:
"""Gets the if of a run with the given name and experiment.
Args:
experiment_name: Name of the experiment in which to search for the
run.
run_name: Name of the run to search.
Returns:
The id of the run if it exists.
"""
self.configure_mlflow()
experiment_name = self._adjust_experiment_name(experiment_name)
runs = mlflow.search_runs(
experiment_names=[experiment_name],
filter_string=f'tags.mlflow.runName = "{run_name}"',
run_view_type=3,
output_format="list",
)
if not runs:
return None
run: Run = runs[0]
if mlflow_utils.is_zenml_run(run):
return cast(str, run.info.run_id)
else:
return None
get_step_run_metadata(self, info)
Get component- and step-specific metadata after a step ran.
Parameters:

Name | Type | Description | Default
---|---|---|---
`info` | `StepRunInfo` | Info about the step that was executed. | required

Returns:

Type | Description
---|---
`Dict[str, MetadataType]` | A dictionary of metadata.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_step_run_metadata(
self, info: "StepRunInfo"
) -> Dict[str, "MetadataType"]:
"""Get component- and step-specific metadata after a step ran.
Args:
info: Info about the step that was executed.
Returns:
A dictionary of metadata.
"""
metadata: Dict[str, Any] = {
METADATA_EXPERIMENT_TRACKER_URL: Uri(
self.get_tracking_uri(as_plain_text=False)
),
}
if run := mlflow.active_run():
metadata["mlflow_run_id"] = run.info.run_id
metadata["mlflow_experiment_id"] = run.info.experiment_id
return metadata
get_tracking_uri(self, as_plain_text=True)
Returns the configured tracking URI or a local fallback.
Parameters:

Name | Type | Description | Default
---|---|---|---
`as_plain_text` | `bool` | Whether to return the tracking URI as plain text. | `True`

Returns:

Type | Description
---|---
`str` | The tracking URI.
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_tracking_uri(self, as_plain_text: bool = True) -> str:
"""Returns the configured tracking URI or a local fallback.
Args:
as_plain_text: Whether to return the tracking URI as plain text.
Returns:
The tracking URI.
"""
if as_plain_text:
tracking_uri = self.config.tracking_uri
else:
tracking_uri = self.config.model_dump()["tracking_uri"]
return tracking_uri or self._local_mlflow_backend()
prepare_step_run(self, info)
Sets the MLflow tracking uri and credentials.
Parameters:

Name | Type | Description | Default
---|---|---|---
`info` | `StepRunInfo` | Info about the step that will be executed. | required
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def prepare_step_run(self, info: "StepRunInfo") -> None:
"""Sets the MLflow tracking uri and credentials.
Args:
info: Info about the step that will be executed.
"""
self.configure_mlflow()
settings = cast(
MLFlowExperimentTrackerSettings,
self.get_settings(info),
)
experiment_name = settings.experiment_name or info.pipeline.name
experiment = self._set_active_experiment(experiment_name)
run_id = self.get_run_id(
experiment_name=experiment_name, run_name=info.run_name
)
tags = settings.tags.copy()
tags.update(self._get_internal_tags())
mlflow.start_run(
run_id=run_id,
run_name=info.run_name,
experiment_id=experiment.experiment_id,
tags=tags,
)
if settings.nested:
mlflow.start_run(
run_name=info.pipeline_step_name, nested=True, tags=tags
)
flavors
special
MLFlow integration flavors.
mlflow_experiment_tracker_flavor
MLflow experiment tracker flavor.
MLFlowExperimentTrackerConfig (BaseExperimentTrackerConfig, MLFlowExperimentTrackerSettings)
Config for the MLflow experiment tracker.
Attributes:

Name | Type | Description
---|---|---
`tracking_uri` | `Optional[str]` | The uri of the mlflow tracking server. If no uri is set, your stack must contain a `LocalArtifactStore` and ZenML will point MLflow to a subdirectory of your artifact store instead.
`tracking_username` | `Optional[str]` | Username for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either `tracking_token` or `tracking_username` and `tracking_password` must be specified.
`tracking_password` | `Optional[str]` | Password for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either `tracking_token` or `tracking_username` and `tracking_password` must be specified.
`tracking_token` | `Optional[str]` | Token for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either `tracking_token` or `tracking_username` and `tracking_password` must be specified.
`tracking_insecure_tls` | `bool` | Skips verification of TLS connection to the MLflow tracking server if set to `True`.
`databricks_host` | `Optional[str]` | The host of the Databricks workspace with the MLflow managed server to connect to. This is only required if the `tracking_uri` value is set to `"databricks"`.
`enable_unity_catalog` | `bool` | If `True`, will enable the Databricks Unity Catalog for logging and registering models.
Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerConfig(
BaseExperimentTrackerConfig, MLFlowExperimentTrackerSettings
):
"""Config for the MLflow experiment tracker.
Attributes:
tracking_uri: The uri of the mlflow tracking server. If no uri is set,
your stack must contain a `LocalArtifactStore` and ZenML will
point MLflow to a subdirectory of your artifact store instead.
tracking_username: Username for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified.
tracking_password: Password for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified.
tracking_token: Token for authenticating with the MLflow
tracking server. When a remote tracking uri is specified,
either `tracking_token` or `tracking_username` and
`tracking_password` must be specified.
tracking_insecure_tls: Skips verification of TLS connection to the
MLflow tracking server if set to `True`.
databricks_host: The host of the Databricks workspace with the MLflow
managed server to connect to. This is only required if
`tracking_uri` value is set to `"databricks"`.
enable_unity_catalog: If `True`, will enable the Databricks Unity Catalog for
logging and registering models.
"""
tracking_uri: Optional[str] = None
tracking_username: Optional[str] = SecretField(default=None)
tracking_password: Optional[str] = SecretField(default=None)
tracking_token: Optional[str] = SecretField(default=None)
tracking_insecure_tls: bool = False
databricks_host: Optional[str] = None
enable_unity_catalog: bool = False
@model_validator(mode="after")
def _ensure_authentication_if_necessary(
self,
) -> "MLFlowExperimentTrackerConfig":
"""Ensures that credentials or a token for authentication exist.
We make this check when running MLflow tracking with a remote backend.
Returns:
The validated values.
Raises:
ValueError: If neither credentials nor a token are provided.
"""
if self.tracking_uri:
if is_databricks_tracking_uri(self.tracking_uri):
# If the tracking uri is "databricks", then we need the
# databricks host to be set.
if not self.databricks_host:
raise ValueError(
"MLflow experiment tracking with a Databricks MLflow "
"managed tracking server requires the "
"`databricks_host` to be set in your stack component. "
"To update your component, run "
"`zenml experiment-tracker update "
"<NAME> --databricks_host=DATABRICKS_HOST` "
"and specify the hostname of your Databricks workspace."
)
if is_remote_mlflow_tracking_uri(self.tracking_uri):
# we need either username + password or a token to authenticate
# to the remote backend
basic_auth = self.tracking_username and self.tracking_password
if not (basic_auth or self.tracking_token):
raise ValueError(
f"MLflow experiment tracking with a remote backend "
f"{self.tracking_uri} is only possible when specifying "
f"either username and password or an authentication "
f"token in your stack component. To update your "
f"component, run the following command: "
f"`zenml experiment-tracker update "
f"<NAME> --tracking_username=MY_USERNAME "
f"--tracking_password=MY_PASSWORD "
f"--tracking_token=MY_TOKEN` and specify either your "
f"username and password or token."
)
return self
@property
def is_local(self) -> bool:
"""Checks if this stack component is running locally.
Returns:
True if this config is for a local component, False otherwise.
"""
if not self.tracking_uri or not is_remote_mlflow_tracking_uri(
self.tracking_uri
):
return True
return False
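The `_ensure_authentication_if_necessary` validator above runs at construction time; a sketch of what it enforces (the URI and token below are placeholders):

```python
# Sketch: a remote tracking URI without credentials is rejected at construction.
from zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor import (
    MLFlowExperimentTrackerConfig,
)

try:
    MLFlowExperimentTrackerConfig(tracking_uri="https://mlflow.example.com")
except ValueError as e:
    print(e)  # asks for tracking_token, or tracking_username + tracking_password

# Supplying a token (or username and password) satisfies the validator.
config = MLFlowExperimentTrackerConfig(
    tracking_uri="https://mlflow.example.com",
    tracking_token="my-token",  # placeholder credential
)
```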
is_local: bool
property
readonly
Checks if this stack component is running locally.
Returns:

Type | Description
---|---
`bool` | True if this config is for a local component, False otherwise.
MLFlowExperimentTrackerFlavor (BaseExperimentTrackerFlavor)
Class for the `MLFlowExperimentTrackerFlavor`.
Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerFlavor(BaseExperimentTrackerFlavor):
"""Class for the `MLFlowExperimentTrackerFlavor`."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR
@property
def docs_url(self) -> Optional[str]:
"""A url to point at docs explaining this flavor.
Returns:
A flavor docs url.
"""
return self.generate_default_docs_url()
@property
def sdk_docs_url(self) -> Optional[str]:
"""A url to point at SDK docs explaining this flavor.
Returns:
A flavor SDK docs url.
"""
return self.generate_default_sdk_docs_url()
@property
def logo_url(self) -> str:
"""A url to represent the flavor in the dashboard.
Returns:
The flavor logo.
"""
return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/experiment_tracker/mlflow.png"
@property
def config_class(self) -> Type[MLFlowExperimentTrackerConfig]:
"""Returns `MLFlowExperimentTrackerConfig` config class.
Returns:
The config class.
"""
return MLFlowExperimentTrackerConfig
@property
def implementation_class(self) -> Type["MLFlowExperimentTracker"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.mlflow.experiment_trackers import (
MLFlowExperimentTracker,
)
return MLFlowExperimentTracker
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor.MLFlowExperimentTrackerConfig]
property
readonly
Returns the `MLFlowExperimentTrackerConfig` config class.
Returns:

Type | Description
---|---
`Type[zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor.MLFlowExperimentTrackerConfig]` | The config class.

docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:

Type | Description
---|---
`Optional[str]` | A flavor docs url.

implementation_class: Type[MLFlowExperimentTracker]
property
readonly
Implementation class for this flavor.
Returns:

Type | Description
---|---
`Type[MLFlowExperimentTracker]` | The implementation class.

logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:

Type | Description
---|---
`str` | The flavor logo.

name: str
property
readonly
Name of the flavor.
Returns:

Type | Description
---|---
`str` | The name of the flavor.

sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:

Type | Description
---|---
`Optional[str]` | A flavor SDK docs url.
MLFlowExperimentTrackerSettings (BaseSettings)
Settings for the MLflow experiment tracker.
Attributes:

Name | Type | Description
---|---|---
`experiment_name` | `Optional[str]` | The MLflow experiment name.
`nested` | `bool` | If `True`, will create a nested sub-run for the step.
`tags` | `Dict[str, Any]` | Tags for the MLflow run.
Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerSettings(BaseSettings):
"""Settings for the MLflow experiment tracker.
Attributes:
experiment_name: The MLflow experiment name.
nested: If `True`, will create a nested sub-run for the step.
tags: Tags for the Mlflow run.
"""
experiment_name: Optional[str] = None
nested: bool = False
tags: Dict[str, Any] = {}
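These settings are typically supplied per step through ZenML's `settings` mechanism; a sketch (the tracker name `mlflow_tracker` is an assumption):

```python
# Sketch: overriding the experiment name, nesting and tags for one step.
import mlflow

from zenml import step
from zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor import (
    MLFlowExperimentTrackerSettings,
)

mlflow_settings = MLFlowExperimentTrackerSettings(
    experiment_name="my-experiment",  # placeholder experiment name
    nested=True,                      # open a nested sub-run for this step
    tags={"team": "ml-platform"},     # placeholder tag
)


@step(
    experiment_tracker="mlflow_tracker",
    settings={"experiment_tracker.mlflow": mlflow_settings},
)
def evaluate() -> None:
    mlflow.log_metric("f1", 0.87)  # placeholder metric
```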
is_databricks_tracking_uri(tracking_uri)
Checks whether the given tracking uri is a Databricks tracking uri.
Parameters:

Name | Type | Description | Default
---|---|---|---
`tracking_uri` | `str` | The tracking uri to check. | required

Returns:

Type | Description
---|---
`bool` | `True` if the tracking uri is a Databricks tracking uri, `False` otherwise.
Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
def is_databricks_tracking_uri(tracking_uri: str) -> bool:
"""Checks whether the given tracking uri is a Databricks tracking uri.
Args:
tracking_uri: The tracking uri to check.
Returns:
`True` if the tracking uri is a Databricks tracking uri, `False`
otherwise.
"""
return tracking_uri == "databricks"
is_remote_mlflow_tracking_uri(tracking_uri)
Checks whether the given tracking uri is remote or not.
Parameters:

Name | Type | Description | Default
---|---|---|---
`tracking_uri` | `str` | The tracking uri to check. | required

Returns:

Type | Description
---|---
`bool` | `True` if the tracking uri is remote, `False` otherwise.
Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
def is_remote_mlflow_tracking_uri(tracking_uri: str) -> bool:
"""Checks whether the given tracking uri is remote or not.
Args:
tracking_uri: The tracking uri to check.
Returns:
`True` if the tracking uri is remote, `False` otherwise.
"""
return any(
tracking_uri.startswith(prefix) for prefix in ["http://", "https://"]
) or is_databricks_tracking_uri(tracking_uri)
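A quick sketch of how the two helpers classify tracking URIs:

```python
# Sketch: URI classification by the helpers above.
from zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor import (
    is_databricks_tracking_uri,
    is_remote_mlflow_tracking_uri,
)

print(is_databricks_tracking_uri("databricks"))             # True
print(is_remote_mlflow_tracking_uri("https://mlflow.dev"))  # True (remote)
print(is_remote_mlflow_tracking_uri("file:/tmp/mlruns"))    # False (local)
```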
mlflow_model_deployer_flavor
MLflow model deployer flavor.
MLFlowModelDeployerConfig (BaseModelDeployerConfig)
Configuration for the MLflow model deployer.
Attributes:

Name | Type | Description
---|---|---
`service_path` | `str` | The path where the local MLflow deployment service configuration, PID and log files are stored.
Source code in zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py
class MLFlowModelDeployerConfig(BaseModelDeployerConfig):
"""Configuration for the MLflow model deployer.
Attributes:
service_path: the path where the local MLflow deployment service
configuration, PID and log files are stored.
"""
service_path: str = ""
@property
def is_local(self) -> bool:
"""Checks if this stack component is running locally.
Returns:
True if this config is for a local component, False otherwise.
"""
return True
is_local: bool
property
readonly
Checks if this stack component is running locally.
Returns:

Type | Description
---|---
`bool` | True if this config is for a local component, False otherwise.
MLFlowModelDeployerFlavor (BaseModelDeployerFlavor)
Model deployer flavor for MLflow models.
Source code in zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py
class MLFlowModelDeployerFlavor(BaseModelDeployerFlavor):
"""Model deployer flavor for MLflow models."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return MLFLOW_MODEL_DEPLOYER_FLAVOR
@property
def docs_url(self) -> Optional[str]:
"""A url to point at docs explaining this flavor.
Returns:
A flavor docs url.
"""
return self.generate_default_docs_url()
@property
def sdk_docs_url(self) -> Optional[str]:
"""A url to point at SDK docs explaining this flavor.
Returns:
A flavor SDK docs url.
"""
return self.generate_default_sdk_docs_url()
@property
def logo_url(self) -> str:
"""A url to represent the flavor in the dashboard.
Returns:
The flavor logo.
"""
return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png"
@property
def config_class(self) -> Type[MLFlowModelDeployerConfig]:
"""Returns `MLFlowModelDeployerConfig` config class.
Returns:
The config class.
"""
return MLFlowModelDeployerConfig
@property
def implementation_class(self) -> Type["MLFlowModelDeployer"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.mlflow.model_deployers import (
MLFlowModelDeployer,
)
return MLFlowModelDeployer
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig]
property
readonly
Returns the `MLFlowModelDeployerConfig` config class.
Returns:

Type | Description
---|---
`Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig]` | The config class.

docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:

Type | Description
---|---
`Optional[str]` | A flavor docs url.

implementation_class: Type[MLFlowModelDeployer]
property
readonly
Implementation class for this flavor.
Returns:

Type | Description
---|---
`Type[MLFlowModelDeployer]` | The implementation class.

logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:

Type | Description
---|---
`str` | The flavor logo.

name: str
property
readonly
Name of the flavor.
Returns:

Type | Description
---|---
`str` | The name of the flavor.

sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:

Type | Description
---|---
`Optional[str]` | A flavor SDK docs url.
mlflow_model_registry_flavor
MLflow model registry flavor.
MLFlowModelRegistryConfig (BaseModelRegistryConfig)
Configuration for the MLflow model registry.
Source code in zenml/integrations/mlflow/flavors/mlflow_model_registry_flavor.py
class MLFlowModelRegistryConfig(BaseModelRegistryConfig):
"""Configuration for the MLflow model registry."""
MLFlowModelRegistryFlavor (BaseModelRegistryFlavor)
Model registry flavor for MLflow models.
Source code in zenml/integrations/mlflow/flavors/mlflow_model_registry_flavor.py
class MLFlowModelRegistryFlavor(BaseModelRegistryFlavor):
"""Model registry flavor for MLflow models."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return MLFLOW_MODEL_REGISTRY_FLAVOR
@property
def docs_url(self) -> Optional[str]:
"""A url to point at docs explaining this flavor.
Returns:
A flavor docs url.
"""
return self.generate_default_docs_url()
@property
def sdk_docs_url(self) -> Optional[str]:
"""A url to point at SDK docs explaining this flavor.
Returns:
A flavor SDK docs url.
"""
return self.generate_default_sdk_docs_url()
@property
def logo_url(self) -> str:
"""A url to represent the flavor in the dashboard.
Returns:
The flavor logo.
"""
return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png"
@property
def config_class(self) -> Type[MLFlowModelRegistryConfig]:
"""Returns `MLFlowModelRegistryConfig` config class.
Returns:
The config class.
"""
return MLFlowModelRegistryConfig
@property
def implementation_class(self) -> Type["MLFlowModelRegistry"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.mlflow.model_registries import (
MLFlowModelRegistry,
)
return MLFlowModelRegistry
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_model_registry_flavor.MLFlowModelRegistryConfig]
property
readonly
Returns the `MLFlowModelRegistryConfig` config class.
Returns:

Type | Description
---|---
`Type[zenml.integrations.mlflow.flavors.mlflow_model_registry_flavor.MLFlowModelRegistryConfig]` | The config class.

docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:

Type | Description
---|---
`Optional[str]` | A flavor docs url.

implementation_class: Type[MLFlowModelRegistry]
property
readonly
Implementation class for this flavor.
Returns:

Type | Description
---|---
`Type[MLFlowModelRegistry]` | The implementation class.

logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:

Type | Description
---|---
`str` | The flavor logo.

name: str
property
readonly
Name of the flavor.
Returns:

Type | Description
---|---
`str` | The name of the flavor.

sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:

Type | Description
---|---
`Optional[str]` | A flavor SDK docs url.
mlflow_utils
Implementation of utils specific to the MLflow integration.
get_missing_mlflow_experiment_tracker_error()
Returns description of how to add an MLflow experiment tracker to your stack.
Returns:

Type | Description
---|---
`ValueError` | If no MLflow experiment tracker is registered in the active stack.
Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_missing_mlflow_experiment_tracker_error() -> ValueError:
"""Returns description of how to add an MLflow experiment tracker to your stack.
Returns:
ValueError: If no MLflow experiment tracker is registered in the active stack.
"""
return ValueError(
"The active stack needs to have a MLflow experiment tracker "
"component registered to be able to track experiments using "
"MLflow. You can create a new stack with a MLflow experiment "
"tracker component or update your existing stack to add this "
"component, e.g.:\n\n"
" 'zenml experiment-tracker register mlflow_tracker "
"--type=mlflow'\n"
" 'zenml stack register stack-name -e mlflow_tracker ...'\n"
)
get_tracking_uri()
Gets the MLflow tracking URI from the active experiment tracking stack component.
Returns:

Type | Description
---|---
`str` | MLflow tracking URI.
Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_tracking_uri() -> str:
"""Gets the MLflow tracking URI from the active experiment tracking stack component.
# noqa: DAR401
Returns:
MLflow tracking URI.
"""
from zenml.integrations.mlflow.experiment_trackers.mlflow_experiment_tracker import (
MLFlowExperimentTracker,
)
tracker = Client().active_stack.experiment_tracker
if tracker is None or not isinstance(tracker, MLFlowExperimentTracker):
raise get_missing_mlflow_experiment_tracker_error()
return tracker.get_tracking_uri()
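This helper is useful for pointing a plain MLflow client at the same backend the active tracker uses, e.g. to query runs after a pipeline finishes; a sketch:

```python
# Sketch: browsing ZenML-recorded runs with a vanilla MLflow client.
import mlflow

from zenml.integrations.mlflow.mlflow_utils import get_tracking_uri

mlflow.set_tracking_uri(get_tracking_uri())
client = mlflow.MlflowClient()
for experiment in client.search_experiments():
    print(experiment.name, experiment.experiment_id)
```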
is_zenml_run(run)
Checks if an MLflow run is a ZenML run or not.
Parameters:

Name | Type | Description | Default
---|---|---|---
`run` | `mlflow.entities.Run` | The run to check. | required

Returns:

Type | Description
---|---
`bool` | `True` if the run is a ZenML run, `False` otherwise.
Source code in zenml/integrations/mlflow/mlflow_utils.py
def is_zenml_run(run: Run) -> bool:
"""Checks if a MLflow run is a ZenML run or not.
Args:
run: The run to check.
Returns:
If the run is a ZenML run.
"""
return ZENML_TAG_KEY in run.data.tags
stop_zenml_mlflow_runs(status)
Stops active ZenML MLflow runs.
This function stops all MLflow active runs until no active run exists or a non-ZenML run is active.
Parameters:

Name | Type | Description | Default
---|---|---|---
`status` | `str` | The status to set the run to. | required
Source code in zenml/integrations/mlflow/mlflow_utils.py
def stop_zenml_mlflow_runs(status: str) -> None:
"""Stops active ZenML Mlflow runs.
This function stops all MLflow active runs until no active run exists or
a non-ZenML run is active.
Args:
status: The status to set the run to.
"""
active_run = mlflow.active_run()
while active_run:
if is_zenml_run(active_run):
logger.debug("Stopping mlflow run %s.", active_run.info.run_id)
mlflow.end_run(status=status)
active_run = mlflow.active_run()
else:
break
model_deployers
special
Initialization of the MLflow model deployers.
mlflow_model_deployer
Implementation of the MLflow model deployer.
MLFlowModelDeployer (BaseModelDeployer)
MLflow implementation of the BaseModelDeployer.
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
class MLFlowModelDeployer(BaseModelDeployer):
"""MLflow implementation of the BaseModelDeployer."""
NAME: ClassVar[str] = "MLflow"
FLAVOR: ClassVar[Type[BaseModelDeployerFlavor]] = MLFlowModelDeployerFlavor
_service_path: Optional[str] = None
@property
def config(self) -> MLFlowModelDeployerConfig:
"""Returns the `MLFlowModelDeployerConfig` config.
Returns:
The configuration.
"""
return cast(MLFlowModelDeployerConfig, self._config)
@staticmethod
def get_service_path(id_: UUID) -> str:
"""Get the path where local MLflow service information is stored.
This includes the deployment service configuration, PID and log files
are stored.
Args:
id_: The ID of the MLflow model deployer.
Returns:
The service path.
"""
service_path = os.path.join(
GlobalConfiguration().local_stores_path,
str(id_),
)
create_dir_recursive_if_not_exists(service_path)
return service_path
@property
def local_path(self) -> str:
"""Returns the path to the root directory.
This is where all configurations for MLflow deployment daemon processes
are stored.
If the service path is not set in the config by the user, the path is
set to a local default path according to the component ID.
Returns:
The path to the local service root directory.
"""
if self._service_path is not None:
return self._service_path
if self.config.service_path:
self._service_path = self.config.service_path
else:
self._service_path = self.get_service_path(self.id)
create_dir_recursive_if_not_exists(self._service_path)
return self._service_path
@staticmethod
def get_model_server_info( # type: ignore[override]
service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
"""Return implementation specific information relevant to the user.
Args:
service_instance: Instance of an MLFlowDeploymentService
Returns:
A dictionary containing the information.
"""
return {
"PREDICTION_URL": service_instance.endpoint.prediction_url,
"MODEL_URI": service_instance.config.model_uri,
"MODEL_NAME": service_instance.config.model_name,
"REGISTRY_MODEL_NAME": service_instance.config.registry_model_name,
"REGISTRY_MODEL_VERSION": service_instance.config.registry_model_version,
"SERVICE_PATH": service_instance.status.runtime_path,
"DAEMON_PID": str(service_instance.status.pid),
"HEALTH_CHECK_URL": service_instance.endpoint.monitor.get_healthcheck_uri(
service_instance.endpoint
),
}
def perform_deploy_model(
self,
id: UUID,
config: ServiceConfig,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
"""Create a new MLflow deployment service or update an existing one.
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the `replace`
argument value:
* if `replace` is False, calling this method will create a new MLflow
deployment server to reflect the model and other configuration
parameters specified in the supplied MLflow service `config`.
* if `replace` is True, this method will first attempt to find an
existing MLflow deployment service that is *equivalent* to the
supplied configuration parameters. Two or more MLflow deployment
services are considered equivalent if they have the same
`pipeline_name`, `pipeline_step_name` and `model_name` configuration
parameters. To put it differently, two MLflow deployment services
are equivalent if they serve versions of the same model deployed by
the same pipeline step. If an equivalent MLflow deployment is found,
it will be updated in place to reflect the new configuration
parameters.
Callers should set `replace` to True if they want a continuous model
deployment workflow that doesn't spin up a new MLflow deployment
server for each new model version. If multiple equivalent MLflow
deployment servers are found, one is selected at random to be updated
and the others are deleted.
Args:
id: the ID of the MLflow deployment service to be created or updated.
config: the configuration of the model to be deployed with MLflow.
timeout: the timeout in seconds to wait for the MLflow server
to be provisioned and successfully started or updated. If set
to 0, the method will return immediately after the MLflow
server is provisioned, without waiting for it to fully start.
Returns:
The ZenML MLflow deployment service object that can be used to
interact with the MLflow model server.
"""
config = cast(MLFlowDeploymentConfig, config)
service = self._create_new_service(
id=id, timeout=timeout, config=config
)
logger.info(f"Created a new MLflow deployment service: {service}")
return service
def _clean_up_existing_service(
self,
timeout: int,
force: bool,
existing_service: MLFlowDeploymentService,
) -> None:
# stop the older service
existing_service.stop(timeout=timeout, force=force)
# delete the old configuration file
if existing_service.status.runtime_path:
shutil.rmtree(existing_service.status.runtime_path)
# the step will receive a config from the user that mentions the number
# of workers etc.the step implementation will create a new config using
# all values from the user and add values like pipeline name, model_uri
def _create_new_service(
self, id: UUID, timeout: int, config: MLFlowDeploymentConfig
) -> MLFlowDeploymentService:
"""Creates a new MLFlowDeploymentService.
Args:
id: the ID of the MLflow deployment service to be created or updated.
timeout: the timeout in seconds to wait for the MLflow server
to be provisioned and successfully started or updated.
config: the configuration of the model to be deployed with MLflow.
Returns:
The MLFlowDeploymentService object that can be used to interact
with the MLflow model server.
"""
# set the root runtime path with the stack component's UUID
config.root_runtime_path = self.local_path
# create a new service for the new model
service = MLFlowDeploymentService(uuid=id, config=config)
service.start(timeout=timeout)
return service
def perform_stop_model(
self,
service: BaseService,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
force: bool = False,
) -> BaseService:
"""Method to stop a model server.
Args:
service: The service to stop.
timeout: Timeout in seconds to wait for the service to stop.
force: If True, force the service to stop.
Returns:
The service that was stopped.
"""
service.stop(timeout=timeout, force=force)
return service
def perform_start_model(
self,
service: BaseService,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
"""Method to start a model server.
Args:
service: The service to start.
timeout: Timeout in seconds to wait for the service to start.
Returns:
The service that was started.
"""
service.start(timeout=timeout)
return service
def perform_delete_model(
self,
service: BaseService,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Method to delete all configuration of a model server.
Args:
service: The service to delete.
timeout: Timeout in seconds to wait for the service to stop.
force: If True, force the service to stop.
"""
service = cast(MLFlowDeploymentService, service)
self._clean_up_existing_service(
existing_service=service, timeout=timeout, force=force
)
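A sketch of locating the active deployer and inspecting its model servers (assuming an MLflow model deployer is part of the active stack):

```python
# Sketch: listing model servers managed by the active MLflow model deployer.
from zenml.integrations.mlflow.model_deployers import MLFlowModelDeployer

deployer = MLFlowModelDeployer.get_active_model_deployer()
for service in deployer.find_model_server():
    # `get_model_server_info` returns prediction URL, model URI, PID, etc.
    print(MLFlowModelDeployer.get_model_server_info(service))
```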
config: MLFlowModelDeployerConfig
property
readonly
Returns the `MLFlowModelDeployerConfig` config.
Returns:

Type | Description
---|---
`MLFlowModelDeployerConfig` | The configuration.

local_path: str
property
readonly
Returns the path to the root directory.
This is where all configurations for MLflow deployment daemon processes are stored.
If the service path is not set in the config by the user, the path is set to a local default path according to the component ID.
Returns:

Type | Description
---|---
`str` | The path to the local service root directory.
FLAVOR (BaseModelDeployerFlavor)
Model deployer flavor for MLflow models.
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
class MLFlowModelDeployerFlavor(BaseModelDeployerFlavor):
"""Model deployer flavor for MLflow models."""
@property
def name(self) -> str:
"""Name of the flavor.
Returns:
The name of the flavor.
"""
return MLFLOW_MODEL_DEPLOYER_FLAVOR
@property
def docs_url(self) -> Optional[str]:
"""A url to point at docs explaining this flavor.
Returns:
A flavor docs url.
"""
return self.generate_default_docs_url()
@property
def sdk_docs_url(self) -> Optional[str]:
"""A url to point at SDK docs explaining this flavor.
Returns:
A flavor SDK docs url.
"""
return self.generate_default_sdk_docs_url()
@property
def logo_url(self) -> str:
"""A url to represent the flavor in the dashboard.
Returns:
The flavor logo.
"""
return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png"
@property
def config_class(self) -> Type[MLFlowModelDeployerConfig]:
"""Returns `MLFlowModelDeployerConfig` config class.
Returns:
The config class.
"""
return MLFlowModelDeployerConfig
@property
def implementation_class(self) -> Type["MLFlowModelDeployer"]:
"""Implementation class for this flavor.
Returns:
The implementation class.
"""
from zenml.integrations.mlflow.model_deployers import (
MLFlowModelDeployer,
)
return MLFlowModelDeployer
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig]
property
readonly
Returns the `MLFlowModelDeployerConfig` config class.
Returns:

Type | Description
---|---
`Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig]` | The config class.

docs_url: Optional[str]
property
readonly
A url to point at docs explaining this flavor.
Returns:

Type | Description
---|---
`Optional[str]` | A flavor docs url.

implementation_class: Type[MLFlowModelDeployer]
property
readonly
Implementation class for this flavor.
Returns:

Type | Description
---|---
`Type[MLFlowModelDeployer]` | The implementation class.

logo_url: str
property
readonly
A url to represent the flavor in the dashboard.
Returns:

Type | Description
---|---
`str` | The flavor logo.

name: str
property
readonly
Name of the flavor.
Returns:

Type | Description
---|---
`str` | The name of the flavor.

sdk_docs_url: Optional[str]
property
readonly
A url to point at SDK docs explaining this flavor.
Returns:

Type | Description
---|---
`Optional[str]` | A flavor SDK docs url.
get_model_server_info(service_instance)
staticmethod
Return implementation specific information relevant to the user.
Parameters:

Name | Type | Description | Default
---|---|---|---
`service_instance` | `MLFlowDeploymentService` | Instance of an MLFlowDeploymentService. | required

Returns:

Type | Description
---|---
`Dict[str, Optional[str]]` | A dictionary containing the information.
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_model_server_info( # type: ignore[override]
service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
"""Return implementation specific information relevant to the user.
Args:
service_instance: Instance of an MLFlowDeploymentService
Returns:
A dictionary containing the information.
"""
return {
"PREDICTION_URL": service_instance.endpoint.prediction_url,
"MODEL_URI": service_instance.config.model_uri,
"MODEL_NAME": service_instance.config.model_name,
"REGISTRY_MODEL_NAME": service_instance.config.registry_model_name,
"REGISTRY_MODEL_VERSION": service_instance.config.registry_model_version,
"SERVICE_PATH": service_instance.status.runtime_path,
"DAEMON_PID": str(service_instance.status.pid),
"HEALTH_CHECK_URL": service_instance.endpoint.monitor.get_healthcheck_uri(
service_instance.endpoint
),
}
get_service_path(id_)
staticmethod
Get the path where local MLflow service information is stored.
This includes the deployment service configuration, PID and log files.
Parameters:

Name | Type | Description | Default
---|---|---|---
`id_` | `UUID` | The ID of the MLflow model deployer. | required

Returns:

Type | Description
---|---
`str` | The service path.
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_service_path(id_: UUID) -> str:
"""Get the path where local MLflow service information is stored.
This includes the deployment service configuration, PID and log files
are stored.
Args:
id_: The ID of the MLflow model deployer.
Returns:
The service path.
"""
service_path = os.path.join(
GlobalConfiguration().local_stores_path,
str(id_),
)
create_dir_recursive_if_not_exists(service_path)
return service_path
perform_delete_model(self, service, timeout=60, force=False)
Method to delete all configuration of a model server.
Parameters:

Name | Type | Description | Default
---|---|---|---
`service` | `BaseService` | The service to delete. | required
`timeout` | `int` | Timeout in seconds to wait for the service to stop. | `60`
`force` | `bool` | If True, force the service to stop. | `False`
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def perform_delete_model(
self,
service: BaseService,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
force: bool = False,
) -> None:
"""Method to delete all configuration of a model server.
Args:
service: The service to delete.
timeout: Timeout in seconds to wait for the service to stop.
force: If True, force the service to stop.
"""
service = cast(MLFlowDeploymentService, service)
self._clean_up_existing_service(
existing_service=service, timeout=timeout, force=force
)
perform_deploy_model(self, id, config, timeout=60)
Create a new MLflow deployment service or update an existing one.
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the `replace` argument value:

* if `replace` is False, calling this method will create a new MLflow deployment server to reflect the model and other configuration parameters specified in the supplied MLflow service `config`.
* if `replace` is True, this method will first attempt to find an existing MLflow deployment service that is *equivalent* to the supplied configuration parameters. Two or more MLflow deployment services are considered equivalent if they have the same `pipeline_name`, `pipeline_step_name` and `model_name` configuration parameters. To put it differently, two MLflow deployment services are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent MLflow deployment is found, it will be updated in place to reflect the new configuration parameters.

Callers should set `replace` to True if they want a continuous model deployment workflow that doesn't spin up a new MLflow deployment server for each new model version. If multiple equivalent MLflow deployment servers are found, one is selected at random to be updated and the others are deleted.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id | UUID | the ID of the MLflow deployment service to be created or updated. | required |
config | ServiceConfig | the configuration of the model to be deployed with MLflow. | required |
timeout | int | the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow server is provisioned, without waiting for it to fully start. | 60 |
Returns:
Type | Description |
---|---|
BaseService | The ZenML MLflow deployment service object that can be used to interact with the MLflow model server. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def perform_deploy_model(
self,
id: UUID,
config: ServiceConfig,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
"""Create a new MLflow deployment service or update an existing one.
This should serve the supplied model and deployment configuration.
This method has two modes of operation, depending on the `replace`
argument value:
* if `replace` is False, calling this method will create a new MLflow
deployment server to reflect the model and other configuration
parameters specified in the supplied MLflow service `config`.
* if `replace` is True, this method will first attempt to find an
existing MLflow deployment service that is *equivalent* to the
supplied configuration parameters. Two or more MLflow deployment
services are considered equivalent if they have the same
`pipeline_name`, `pipeline_step_name` and `model_name` configuration
parameters. To put it differently, two MLflow deployment services
are equivalent if they serve versions of the same model deployed by
the same pipeline step. If an equivalent MLflow deployment is found,
it will be updated in place to reflect the new configuration
parameters.
Callers should set `replace` to True if they want a continuous model
deployment workflow that doesn't spin up a new MLflow deployment
server for each new model version. If multiple equivalent MLflow
deployment servers are found, one is selected at random to be updated
and the others are deleted.
Args:
id: the ID of the MLflow deployment service to be created or updated.
config: the configuration of the model to be deployed with MLflow.
timeout: the timeout in seconds to wait for the MLflow server
to be provisioned and successfully started or updated. If set
to 0, the method will return immediately after the MLflow
server is provisioned, without waiting for it to fully start.
Returns:
The ZenML MLflow deployment service object that can be used to
interact with the MLflow model server.
"""
config = cast(MLFlowDeploymentConfig, config)
service = self._create_new_service(
id=id, timeout=timeout, config=config
)
logger.info(f"Created a new MLflow deployment service: {service}")
return service
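A minimal deployment sketch, assuming an MLflow model deployer is registered in the active stack; the model URI, names and service ID below are hypothetical, and `get_active_model_deployer` is the base model deployer helper for fetching the component from the active stack.

```python
from uuid import uuid4

from zenml.integrations.mlflow.model_deployers.mlflow_model_deployer import (
    MLFlowModelDeployer,
)
from zenml.integrations.mlflow.services.mlflow_deployment import (
    MLFlowDeploymentConfig,
)

# Hypothetical model URI and names; substitute values from your own runs.
config = MLFlowDeploymentConfig(
    name="mlflow-deployment-example",
    model_uri="runs:/<run_id>/model",
    model_name="my_model",
    workers=1,
    mlserver=False,
    timeout=60,
)

deployer = MLFlowModelDeployer.get_active_model_deployer()
service = deployer.perform_deploy_model(id=uuid4(), config=config, timeout=60)
print(service.prediction_url)
```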
perform_start_model(self, service, timeout=60)
Method to start a model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service | BaseService | The service to start. | required |
timeout | int | Timeout in seconds to wait for the service to start. | 60 |
Returns:
Type | Description |
---|---|
BaseService | The service that was started. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def perform_start_model(
self,
service: BaseService,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
"""Method to start a model server.
Args:
service: The service to start.
timeout: Timeout in seconds to wait for the service to start.
Returns:
The service that was started.
"""
service.start(timeout=timeout)
return service
perform_stop_model(self, service, timeout=60, force=False)
Method to stop a model server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
service | BaseService | The service to stop. | required |
timeout | int | Timeout in seconds to wait for the service to stop. | 60 |
force | bool | If True, force the service to stop. | False |
Returns:
Type | Description |
---|---|
BaseService | The service that was stopped. |
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def perform_stop_model(
self,
service: BaseService,
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
force: bool = False,
) -> BaseService:
"""Method to stop a model server.
Args:
service: The service to stop.
timeout: Timeout in seconds to wait for the service to stop.
force: If True, force the service to stop.
Returns:
The service that was stopped.
"""
service.stop(timeout=timeout, force=force)
return service
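Taken together with `perform_start_model`, a hedged restart sketch might look as follows; the filter arguments passed to `find_model_server` are illustrative and may differ between ZenML versions.

```python
from zenml.integrations.mlflow.model_deployers.mlflow_model_deployer import (
    MLFlowModelDeployer,
)

deployer = MLFlowModelDeployer.get_active_model_deployer()

# Hypothetical pipeline and step names used to look up an existing server.
services = deployer.find_model_server(
    pipeline_name="training_pipeline",
    pipeline_step_name="mlflow_model_deployer_step",
)
if services:
    service = services[0]
    # Stop the daemon, waiting up to 60 seconds for it to shut down ...
    service = deployer.perform_stop_model(service, timeout=60, force=False)
    # ... then bring it back up again.
    service = deployer.perform_start_model(service, timeout=60)
```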
model_registries
special
Initialization of the MLflow model registry.
mlflow_model_registry
Implementation of the MLflow model registry for ZenML.
MLFlowModelRegistry (BaseModelRegistry)
Register models using MLflow.
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
class MLFlowModelRegistry(BaseModelRegistry):
"""Register models using MLflow."""
_client: Optional[MlflowClient] = None
@property
def config(self) -> MLFlowModelRegistryConfig:
"""Returns the `MLFlowModelRegistryConfig` config.
Returns:
The configuration.
"""
return cast(MLFlowModelRegistryConfig, self._config)
def configure_mlflow(self) -> None:
"""Configures the MLflow Client with the experiment tracker config."""
experiment_tracker = Client().active_stack.experiment_tracker
assert isinstance(experiment_tracker, MLFlowExperimentTracker)
experiment_tracker.configure_mlflow()
@property
def mlflow_client(self) -> MlflowClient:
"""Get the MLflow client.
Returns:
The MLFlowClient.
"""
if not self._client:
self.configure_mlflow()
self._client = mlflow.tracking.MlflowClient()
return self._client
@property
def validator(self) -> Optional[StackValidator]:
"""Validates that the stack contains an mlflow experiment tracker.
Returns:
A StackValidator instance.
"""
def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]:
"""Validates that all the requirements are met for the stack.
Args:
stack: The stack to validate.
Returns:
A tuple of (is_valid, error_message).
"""
# Validate that the experiment tracker is an mlflow experiment tracker.
experiment_tracker = stack.experiment_tracker
assert experiment_tracker is not None
if experiment_tracker.flavor != "mlflow":
                return False, (
                    "The MLflow model registry requires an MLflow experiment "
                    "tracker. You should register an MLflow experiment "
                    "tracker to the stack using the following command: "
                    "`zenml stack update model_registry -e mlflow_tracker`"
                )
mlflow_version = mlflow.version.VERSION
if (
not mlflow_version >= "2.1.1"
and experiment_tracker.config.is_local
):
return False, (
"The MLflow model registry requires MLflow version "
f"2.1.1 or higher to use a local MLflow registry. "
f"Your current MLflow version is {mlflow_version}."
"You can upgrade MLflow using the following command: "
"`pip install --upgrade mlflow`"
)
return True, ""
return StackValidator(
required_components={
StackComponentType.EXPERIMENT_TRACKER,
},
custom_validation_function=_validate_stack_requirements,
)
# ---------
# Model Registration Methods
# ---------
def register_model(
self,
name: str,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
) -> RegisteredModel:
"""Register a model to the MLflow model registry.
Args:
name: The name of the model.
description: The description of the model.
metadata: The metadata of the model.
Raises:
RuntimeError: If the model already exists.
Returns:
The registered model.
"""
        # Check if the model already exists and refuse to overwrite it.
        try:
            self.get_model(name)
        except KeyError:
            pass
        else:
            raise RuntimeError(
                f"Model with name {name} already exists in the MLflow model "
                f"registry. Please use a different name.",
            )
# Register model.
try:
registered_model = self.mlflow_client.create_registered_model(
name=name,
description=description,
tags=metadata,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to register model with name {name} to the MLflow "
f"model registry: {str(e)}",
)
# Return the registered model.
return RegisteredModel(
name=registered_model.name,
description=registered_model.description,
metadata=registered_model.tags,
)
def delete_model(
self,
name: str,
) -> None:
"""Delete a model from the MLflow model registry.
Args:
name: The name of the model.
Raises:
RuntimeError: If the model does not exist.
"""
# Check if model exists.
self.get_model(name=name)
# Delete the registered model.
try:
self.mlflow_client.delete_registered_model(
name=name,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to delete model with name {name} from MLflow model "
f"registry: {str(e)}",
)
def update_model(
self,
name: str,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
remove_metadata: Optional[List[str]] = None,
) -> RegisteredModel:
"""Update a model in the MLflow model registry.
Args:
name: The name of the model.
description: The description of the model.
metadata: The metadata of the model.
remove_metadata: The metadata to remove from the model.
Raises:
RuntimeError: If mlflow fails to update the model.
Returns:
The updated model.
"""
# Check if model exists.
self.get_model(name=name)
# Update the registered model description.
if description:
try:
self.mlflow_client.update_registered_model(
name=name,
description=description,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to update description for the model {name} in MLflow "
f"model registry: {str(e)}",
)
# Update the registered model tags.
if metadata:
try:
for tag, value in metadata.items():
self.mlflow_client.set_registered_model_tag(
name=name,
key=tag,
value=value,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to update tags for the model {name} in MLflow model "
f"registry: {str(e)}",
)
# Remove tags from the registered model.
if remove_metadata:
try:
for tag in remove_metadata:
self.mlflow_client.delete_registered_model_tag(
name=name,
key=tag,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to remove tags for the model {name} in MLflow model "
f"registry: {str(e)}",
)
# Return the updated registered model.
return self.get_model(name)
def get_model(self, name: str) -> RegisteredModel:
"""Get a model from the MLflow model registry.
Args:
name: The name of the model.
Returns:
The model.
Raises:
KeyError: If mlflow fails to get the model.
"""
# Get the registered model.
try:
registered_model = self.mlflow_client.get_registered_model(
name=name,
)
except MlflowException as e:
raise KeyError(
f"Failed to get model with name {name} from the MLflow model "
f"registry: {str(e)}",
)
# Return the registered model.
return RegisteredModel(
name=registered_model.name,
description=registered_model.description,
metadata=registered_model.tags,
)
def list_models(
self,
name: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
) -> List[RegisteredModel]:
"""List models in the MLflow model registry.
Args:
name: A name to filter the models by.
metadata: The metadata to filter the models by.
Returns:
A list of models (RegisteredModel)
"""
# Set the filter string.
filter_string = ""
if name:
filter_string += f"name='{name}'"
if metadata:
for tag, value in metadata.items():
if filter_string:
filter_string += " AND "
filter_string += f"tags.{tag}='{value}'"
# Get the registered models.
registered_models = self.mlflow_client.search_registered_models(
filter_string=filter_string,
max_results=100,
)
# Return the registered models.
return [
RegisteredModel(
name=registered_model.name,
description=registered_model.description,
metadata=registered_model.tags,
)
for registered_model in registered_models
]
# ---------
# Model Version Methods
# ---------
def register_model_version(
self,
name: str,
version: Optional[str] = None,
model_source_uri: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[ModelRegistryModelMetadata] = None,
**kwargs: Any,
) -> RegistryModelVersion:
"""Register a model version to the MLflow model registry.
Args:
name: The name of the model.
model_source_uri: The source URI of the model.
version: The version of the model.
description: The description of the model version.
metadata: The registry metadata of the model version.
**kwargs: Additional keyword arguments.
Raises:
RuntimeError: If the registered model does not exist.
ValueError: If no model source URI was provided.
Returns:
The registered model version.
"""
if not model_source_uri:
raise ValueError(
"Unable to register model version without model source URI."
)
# Check if the model exists, if not create it.
try:
self.get_model(name=name)
except KeyError:
logger.info(
f"No registered model with name {name} found. Creating a new "
"registered model."
)
self.register_model(
name=name,
)
try:
# Inform the user that the version is ignored.
if version:
                logger.info(
                    "The MLflow model registry does not take a version as "
                    "an argument. Registering a new version for the model "
                    f"'{name}'; a version will be assigned automatically."
                )
metadata_dict = metadata.model_dump() if metadata else {}
# Set the run ID and link.
run_id = metadata_dict.get("mlflow_run_id", None)
run_link = metadata_dict.get("mlflow_run_link", None)
# Register the model version.
registered_model_version = self.mlflow_client.create_model_version(
name=name,
source=model_source_uri,
run_id=run_id,
run_link=run_link,
description=description,
tags=metadata_dict,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to register model version with name '{name}' and "
f"version '{version}' to the MLflow model registry."
f"Error: {e}"
)
# Return the registered model version.
return self._cast_mlflow_version_to_model_version(
registered_model_version
)
def delete_model_version(
self,
name: str,
version: str,
) -> None:
"""Delete a model version from the MLflow model registry.
Args:
name: The name of the model.
version: The version of the model.
Raises:
RuntimeError: If mlflow fails to delete the model version.
"""
self.get_model_version(name=name, version=version)
try:
self.mlflow_client.delete_model_version(
name=name,
version=version,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to delete model version '{version}' of model '{name}'."
f"From the MLflow model registry: {str(e)}",
)
def update_model_version(
self,
name: str,
version: str,
description: Optional[str] = None,
metadata: Optional[ModelRegistryModelMetadata] = None,
remove_metadata: Optional[List[str]] = None,
stage: Optional[ModelVersionStage] = None,
) -> RegistryModelVersion:
"""Update a model version in the MLflow model registry.
Args:
name: The name of the model.
version: The version of the model.
description: The description of the model version.
metadata: The metadata of the model version.
remove_metadata: The metadata to remove from the model version.
stage: The stage of the model version.
Raises:
RuntimeError: If mlflow fails to update the model version.
Returns:
The updated model version.
"""
self.get_model_version(name=name, version=version)
# Update the model description.
if description:
try:
self.mlflow_client.update_model_version(
name=name,
version=version,
description=description,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to update the description of model version "
f"'{name}:{version}' in the MLflow model registry: {str(e)}"
)
# Update the model tags.
if metadata:
try:
for key, value in metadata.model_dump().items():
self.mlflow_client.set_model_version_tag(
name=name,
version=version,
key=key,
value=value,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to update the tags of model version "
f"'{name}:{version}' in the MLflow model registry: {str(e)}"
)
# Remove the model tags.
if remove_metadata:
try:
for key in remove_metadata:
self.mlflow_client.delete_model_version_tag(
name=name,
version=version,
key=key,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to remove the tags of model version "
f"'{name}:{version}' in the MLflow model registry: {str(e)}"
)
# Update the model stage.
if stage:
try:
self.mlflow_client.transition_model_version_stage(
name=name,
version=version,
stage=stage.value,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to update the current stage of model version "
f"'{name}:{version}' in the MLflow model registry: {str(e)}"
)
return self.get_model_version(name, version)
def get_model_version(
self,
name: str,
version: str,
) -> RegistryModelVersion:
"""Get a model version from the MLflow model registry.
Args:
name: The name of the model.
version: The version of the model.
Raises:
KeyError: If the model version does not exist.
Returns:
The model version.
"""
# Get the model version from the MLflow model registry.
try:
mlflow_model_version = self.mlflow_client.get_model_version(
name=name,
version=version,
)
except MlflowException as e:
raise KeyError(
f"Failed to get model version '{name}:{version}' from the "
f"MLflow model registry: {str(e)}"
)
# Return the model version.
return self._cast_mlflow_version_to_model_version(
mlflow_model_version=mlflow_model_version,
)
def list_model_versions(
self,
name: Optional[str] = None,
model_source_uri: Optional[str] = None,
metadata: Optional[ModelRegistryModelMetadata] = None,
stage: Optional[ModelVersionStage] = None,
count: Optional[int] = None,
created_after: Optional[datetime] = None,
created_before: Optional[datetime] = None,
order_by_date: Optional[str] = None,
**kwargs: Any,
) -> List[RegistryModelVersion]:
"""List model versions from the MLflow model registry.
Args:
name: The name of the model.
model_source_uri: The model source URI.
metadata: The metadata of the model version.
stage: The stage of the model version.
count: The maximum number of model versions to return.
created_after: The minimum creation time of the model versions.
created_before: The maximum creation time of the model versions.
order_by_date: The order of the model versions by creation time,
either ascending or descending.
kwargs: Additional keyword arguments.
Returns:
The model versions.
"""
# Set the filter string.
filter_string = ""
if name:
filter_string += f"name='{name}'"
if model_source_uri:
if filter_string:
filter_string += " AND "
filter_string += f"source='{model_source_uri}'"
if "mlflow_run_id" in kwargs and kwargs["mlflow_run_id"]:
if filter_string:
filter_string += " AND "
filter_string += f"run_id='{kwargs['mlflow_run_id']}'"
if metadata:
for tag, value in metadata.model_dump().items():
if value:
if filter_string:
filter_string += " AND "
filter_string += f"tags.{tag}='{value}'"
# Get the model versions.
order_by = []
if order_by_date:
if order_by_date in ["asc", "desc"]:
if order_by_date == "asc":
order_by = ["creation_timestamp ASC"]
else:
order_by = ["creation_timestamp DESC"]
mlflow_model_versions = self.mlflow_client.search_model_versions(
filter_string=filter_string,
order_by=order_by,
)
# Cast the MLflow model versions to the ZenML model version class.
model_versions = []
for mlflow_model_version in mlflow_model_versions:
            # Check if the given MLflow model version matches the given
            # request before casting it
if (
stage
and not ModelVersionStage(mlflow_model_version.current_stage)
== stage
):
continue
if created_after and not (
mlflow_model_version.creation_timestamp
>= created_after.timestamp()
):
continue
if created_before and not (
mlflow_model_version.creation_timestamp
<= created_before.timestamp()
):
continue
try:
model_versions.append(
self._cast_mlflow_version_to_model_version(
mlflow_model_version=mlflow_model_version,
)
)
except (AttributeError, OSError) as e:
# Sometimes, the Model Registry in MLflow can become unusable
# due to failed version registration or misuse. In such rare
# cases, it's best to suppress those versions that are not usable.
logger.warning(
"Error encountered while loading MLflow model version "
f"`{mlflow_model_version.name}:{mlflow_model_version.version}`: {e}"
)
if count and len(model_versions) == count:
return model_versions
return model_versions
def load_model_version(
self,
name: str,
version: str,
**kwargs: Any,
) -> Any:
"""Load a model version from the MLflow model registry.
This method loads the model version from the MLflow model registry
and returns the model. The model is loaded using the `mlflow.pyfunc`
module which takes care of loading the model from the model source
URI for the right framework.
Args:
name: The name of the model.
version: The version of the model.
kwargs: Additional keyword arguments.
Returns:
The model version.
Raises:
KeyError: If the model version does not exist.
"""
try:
self.get_model_version(name=name, version=version)
except KeyError:
raise KeyError(
f"Failed to load model version '{name}:{version}' from the "
f"MLflow model registry: Model version does not exist."
)
# Load the model version.
mlflow_model_version = self.mlflow_client.get_model_version(
name=name,
version=version,
)
return load_model(
model_uri=mlflow_model_version.source,
**kwargs,
)
def get_model_uri_artifact_store(
self,
model_version: RegistryModelVersion,
) -> str:
"""Get the model URI artifact store.
Args:
model_version: The model version.
Returns:
The model URI artifact store.
"""
artifact_store_path = (
f"{Client().active_stack.artifact_store.path}/mlflow"
)
model_source_uri = model_version.model_source_uri.rsplit(":")[-1]
return artifact_store_path + model_source_uri
def _cast_mlflow_version_to_model_version(
self,
mlflow_model_version: MLflowModelVersion,
) -> RegistryModelVersion:
"""Cast an MLflow model version to a model version.
Args:
mlflow_model_version: The MLflow model version.
Returns:
The model version.
"""
metadata = mlflow_model_version.tags or {}
if mlflow_model_version.run_id:
metadata["mlflow_run_id"] = mlflow_model_version.run_id
if mlflow_model_version.run_link:
metadata["mlflow_run_link"] = mlflow_model_version.run_link
try:
from mlflow.models import get_model_info
model_library = (
get_model_info(model_uri=mlflow_model_version.source)
.flavors.get("python_function", {})
.get("loader_module")
)
except ImportError:
model_library = None
return RegistryModelVersion(
registered_model=RegisteredModel(name=mlflow_model_version.name),
model_format=MLFLOW_MODEL_FORMAT,
model_library=model_library,
version=str(mlflow_model_version.version),
created_at=datetime.fromtimestamp(
int(mlflow_model_version.creation_timestamp) / 1e3
),
stage=ModelVersionStage(mlflow_model_version.current_stage),
description=mlflow_model_version.description,
last_updated_at=datetime.fromtimestamp(
int(mlflow_model_version.last_updated_timestamp) / 1e3
),
metadata=ModelRegistryModelMetadata(**metadata),
model_source_uri=mlflow_model_version.source,
)
config: MLFlowModelRegistryConfig
property
readonly
Returns the `MLFlowModelRegistryConfig` config.
Returns:
Type | Description |
---|---|
MLFlowModelRegistryConfig | The configuration. |
mlflow_client: mlflow.MlflowClient
property
readonly
Get the MLflow client.
Returns:
Type | Description |
---|---|
mlflow.MlflowClient | The MLFlowClient. |
validator: Optional[zenml.stack.stack_validator.StackValidator]
property
readonly
Validates that the stack contains an mlflow experiment tracker.
Returns:
Type | Description |
---|---|
Optional[zenml.stack.stack_validator.StackValidator] | A StackValidator instance. |
configure_mlflow(self)
Configures the MLflow Client with the experiment tracker config.
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def configure_mlflow(self) -> None:
"""Configures the MLflow Client with the experiment tracker config."""
experiment_tracker = Client().active_stack.experiment_tracker
assert isinstance(experiment_tracker, MLFlowExperimentTracker)
experiment_tracker.configure_mlflow()
delete_model(self, name)
Delete a model from the MLflow model registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the model. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the model does not exist. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def delete_model(
self,
name: str,
) -> None:
"""Delete a model from the MLflow model registry.
Args:
name: The name of the model.
Raises:
RuntimeError: If the model does not exist.
"""
# Check if model exists.
self.get_model(name=name)
# Delete the registered model.
try:
self.mlflow_client.delete_registered_model(
name=name,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to delete model with name {name} from MLflow model "
f"registry: {str(e)}",
)
delete_model_version(self, name, version)
Delete a model version from the MLflow model registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the model. | required |
version | str | The version of the model. | required |
Exceptions:
Type | Description |
---|---|
RuntimeError | If mlflow fails to delete the model version. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def delete_model_version(
self,
name: str,
version: str,
) -> None:
"""Delete a model version from the MLflow model registry.
Args:
name: The name of the model.
version: The version of the model.
Raises:
RuntimeError: If mlflow fails to delete the model version.
"""
self.get_model_version(name=name, version=version)
try:
self.mlflow_client.delete_model_version(
name=name,
version=version,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to delete model version '{version}' of model '{name}'."
f"From the MLflow model registry: {str(e)}",
)
get_model(self, name)
Get a model from the MLflow model registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the model. | required |
Returns:
Type | Description |
---|---|
RegisteredModel | The model. |
Exceptions:
Type | Description |
---|---|
KeyError | If mlflow fails to get the model. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def get_model(self, name: str) -> RegisteredModel:
"""Get a model from the MLflow model registry.
Args:
name: The name of the model.
Returns:
The model.
Raises:
KeyError: If mlflow fails to get the model.
"""
# Get the registered model.
try:
registered_model = self.mlflow_client.get_registered_model(
name=name,
)
except MlflowException as e:
raise KeyError(
f"Failed to get model with name {name} from the MLflow model "
f"registry: {str(e)}",
)
# Return the registered model.
return RegisteredModel(
name=registered_model.name,
description=registered_model.description,
metadata=registered_model.tags,
)
get_model_uri_artifact_store(self, model_version)
Get the model URI artifact store.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_version | RegistryModelVersion | The model version. | required |
Returns:
Type | Description |
---|---|
str | The model URI artifact store. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def get_model_uri_artifact_store(
self,
model_version: RegistryModelVersion,
) -> str:
"""Get the model URI artifact store.
Args:
model_version: The model version.
Returns:
The model URI artifact store.
"""
artifact_store_path = (
f"{Client().active_stack.artifact_store.path}/mlflow"
)
model_source_uri = model_version.model_source_uri.rsplit(":")[-1]
return artifact_store_path + model_source_uri
get_model_version(self, name, version)
Get a model version from the MLflow model registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the model. | required |
version | str | The version of the model. | required |
Exceptions:
Type | Description |
---|---|
KeyError | If the model version does not exist. |
Returns:
Type | Description |
---|---|
RegistryModelVersion | The model version. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def get_model_version(
self,
name: str,
version: str,
) -> RegistryModelVersion:
"""Get a model version from the MLflow model registry.
Args:
name: The name of the model.
version: The version of the model.
Raises:
KeyError: If the model version does not exist.
Returns:
The model version.
"""
# Get the model version from the MLflow model registry.
try:
mlflow_model_version = self.mlflow_client.get_model_version(
name=name,
version=version,
)
except MlflowException as e:
raise KeyError(
f"Failed to get model version '{name}:{version}' from the "
f"MLflow model registry: {str(e)}"
)
# Return the model version.
return self._cast_mlflow_version_to_model_version(
mlflow_model_version=mlflow_model_version,
)
list_model_versions(self, name=None, model_source_uri=None, metadata=None, stage=None, count=None, created_after=None, created_before=None, order_by_date=None, **kwargs)
List model versions from the MLflow model registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | Optional[str] | The name of the model. | None |
model_source_uri | Optional[str] | The model source URI. | None |
metadata | Optional[zenml.model_registries.base_model_registry.ModelRegistryModelMetadata] | The metadata of the model version. | None |
stage | Optional[zenml.model_registries.base_model_registry.ModelVersionStage] | The stage of the model version. | None |
count | Optional[int] | The maximum number of model versions to return. | None |
created_after | Optional[datetime.datetime] | The minimum creation time of the model versions. | None |
created_before | Optional[datetime.datetime] | The maximum creation time of the model versions. | None |
order_by_date | Optional[str] | The order of the model versions by creation time, either ascending or descending. | None |
kwargs | Any | Additional keyword arguments. | {} |
Returns:
Type | Description |
---|---|
List[zenml.model_registries.base_model_registry.RegistryModelVersion] | The model versions. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def list_model_versions(
self,
name: Optional[str] = None,
model_source_uri: Optional[str] = None,
metadata: Optional[ModelRegistryModelMetadata] = None,
stage: Optional[ModelVersionStage] = None,
count: Optional[int] = None,
created_after: Optional[datetime] = None,
created_before: Optional[datetime] = None,
order_by_date: Optional[str] = None,
**kwargs: Any,
) -> List[RegistryModelVersion]:
"""List model versions from the MLflow model registry.
Args:
name: The name of the model.
model_source_uri: The model source URI.
metadata: The metadata of the model version.
stage: The stage of the model version.
count: The maximum number of model versions to return.
created_after: The minimum creation time of the model versions.
created_before: The maximum creation time of the model versions.
order_by_date: The order of the model versions by creation time,
either ascending or descending.
kwargs: Additional keyword arguments.
Returns:
The model versions.
"""
# Set the filter string.
filter_string = ""
if name:
filter_string += f"name='{name}'"
if model_source_uri:
if filter_string:
filter_string += " AND "
filter_string += f"source='{model_source_uri}'"
if "mlflow_run_id" in kwargs and kwargs["mlflow_run_id"]:
if filter_string:
filter_string += " AND "
filter_string += f"run_id='{kwargs['mlflow_run_id']}'"
if metadata:
for tag, value in metadata.model_dump().items():
if value:
if filter_string:
filter_string += " AND "
filter_string += f"tags.{tag}='{value}'"
# Get the model versions.
order_by = []
if order_by_date:
if order_by_date in ["asc", "desc"]:
if order_by_date == "asc":
order_by = ["creation_timestamp ASC"]
else:
order_by = ["creation_timestamp DESC"]
mlflow_model_versions = self.mlflow_client.search_model_versions(
filter_string=filter_string,
order_by=order_by,
)
# Cast the MLflow model versions to the ZenML model version class.
model_versions = []
for mlflow_model_version in mlflow_model_versions:
        # Check if the given MLflow model version matches the given
        # request before casting it
if (
stage
and not ModelVersionStage(mlflow_model_version.current_stage)
== stage
):
continue
if created_after and not (
mlflow_model_version.creation_timestamp
>= created_after.timestamp()
):
continue
if created_before and not (
mlflow_model_version.creation_timestamp
<= created_before.timestamp()
):
continue
try:
model_versions.append(
self._cast_mlflow_version_to_model_version(
mlflow_model_version=mlflow_model_version,
)
)
except (AttributeError, OSError) as e:
# Sometimes, the Model Registry in MLflow can become unusable
# due to failed version registration or misuse. In such rare
# cases, it's best to suppress those versions that are not usable.
logger.warning(
"Error encountered while loading MLflow model version "
f"`{mlflow_model_version.name}:{mlflow_model_version.version}`: {e}"
)
if count and len(model_versions) == count:
return model_versions
return model_versions
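A filtering sketch, assuming the active stack has this registry configured; the model name and time window are hypothetical.

```python
from datetime import datetime, timedelta

from zenml.client import Client

registry = Client().active_stack.model_registry

# Hypothetical query: the three most recent versions of "my_model"
# registered within the last week, newest first.
versions = registry.list_model_versions(
    name="my_model",
    created_after=datetime.now() - timedelta(days=7),
    order_by_date="desc",
    count=3,
)
for version in versions:
    print(version.version, version.stage, version.created_at)
```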
list_models(self, name=None, metadata=None)
List models in the MLflow model registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | Optional[str] | A name to filter the models by. | None |
metadata | Optional[Dict[str, str]] | The metadata to filter the models by. | None |
Returns:
Type | Description |
---|---|
List[zenml.model_registries.base_model_registry.RegisteredModel] | A list of models (RegisteredModel) |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def list_models(
self,
name: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
) -> List[RegisteredModel]:
"""List models in the MLflow model registry.
Args:
name: A name to filter the models by.
metadata: The metadata to filter the models by.
Returns:
A list of models (RegisteredModel)
"""
# Set the filter string.
filter_string = ""
if name:
filter_string += f"name='{name}'"
if metadata:
for tag, value in metadata.items():
if filter_string:
filter_string += " AND "
filter_string += f"tags.{tag}='{value}'"
# Get the registered models.
registered_models = self.mlflow_client.search_registered_models(
filter_string=filter_string,
max_results=100,
)
# Return the registered models.
return [
RegisteredModel(
name=registered_model.name,
description=registered_model.description,
metadata=registered_model.tags,
)
for registered_model in registered_models
]
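For illustration, the call below lists models matching a name and a metadata tag; as shown in the source above, the filters are compiled into an MLflow filter string such as `name='my_model' AND tags.team='forecasting'`. The names are hypothetical.

```python
from zenml.client import Client

registry = Client().active_stack.model_registry

# Hypothetical name and metadata filter.
models = registry.list_models(
    name="my_model",
    metadata={"team": "forecasting"},
)
for model in models:
    print(model.name, model.description)
```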
load_model_version(self, name, version, **kwargs)
Load a model version from the MLflow model registry.
This method loads the model version from the MLflow model registry
and returns the model. The model is loaded using the mlflow.pyfunc
module which takes care of loading the model from the model source
URI for the right framework.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the model. | required |
version | str | The version of the model. | required |
kwargs | Any | Additional keyword arguments. | {} |
Returns:
Type | Description |
---|---|
Any | The model version. |
Exceptions:
Type | Description |
---|---|
KeyError | If the model version does not exist. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def load_model_version(
self,
name: str,
version: str,
**kwargs: Any,
) -> Any:
"""Load a model version from the MLflow model registry.
This method loads the model version from the MLflow model registry
and returns the model. The model is loaded using the `mlflow.pyfunc`
module which takes care of loading the model from the model source
URI for the right framework.
Args:
name: The name of the model.
version: The version of the model.
kwargs: Additional keyword arguments.
Returns:
The model version.
Raises:
KeyError: If the model version does not exist.
"""
try:
self.get_model_version(name=name, version=version)
except KeyError:
raise KeyError(
f"Failed to load model version '{name}:{version}' from the "
f"MLflow model registry: Model version does not exist."
)
# Load the model version.
mlflow_model_version = self.mlflow_client.get_model_version(
name=name,
version=version,
)
return load_model(
model_uri=mlflow_model_version.source,
**kwargs,
)
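A loading sketch under the same assumptions; since the returned object is an `mlflow.pyfunc` model, it exposes a generic `predict` interface. The model name, version and feature columns are hypothetical.

```python
import pandas as pd

from zenml.client import Client

registry = Client().active_stack.model_registry

# Load version "1" of a hypothetical model via `mlflow.pyfunc`.
model = registry.load_model_version(name="my_model", version="1")

# pyfunc models expose a generic `predict` interface.
data = pd.DataFrame({"feature_a": [1.0], "feature_b": [2.0]})
print(model.predict(data))
```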
register_model(self, name, description=None, metadata=None)
Register a model to the MLflow model registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the model. | required |
description | Optional[str] | The description of the model. | None |
metadata | Optional[Dict[str, str]] | The metadata of the model. | None |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the model already exists. |
Returns:
Type | Description |
---|---|
RegisteredModel | The registered model. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def register_model(
self,
name: str,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
) -> RegisteredModel:
"""Register a model to the MLflow model registry.
Args:
name: The name of the model.
description: The description of the model.
metadata: The metadata of the model.
Raises:
RuntimeError: If the model already exists.
Returns:
The registered model.
"""
    # Check if the model already exists and refuse to overwrite it.
    try:
        self.get_model(name)
    except KeyError:
        pass
    else:
        raise RuntimeError(
            f"Model with name {name} already exists in the MLflow model "
            f"registry. Please use a different name.",
        )
# Register model.
try:
registered_model = self.mlflow_client.create_registered_model(
name=name,
description=description,
tags=metadata,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to register model with name {name} to the MLflow "
f"model registry: {str(e)}",
)
# Return the registered model.
return RegisteredModel(
name=registered_model.name,
description=registered_model.description,
metadata=registered_model.tags,
)
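A registration sketch, assuming the active stack has this registry; the name, description and tags below are all hypothetical.

```python
from zenml.client import Client

registry = Client().active_stack.model_registry

model = registry.register_model(
    name="my_model",
    description="A churn prediction model.",
    metadata={"team": "forecasting", "framework": "sklearn"},
)
print(model.name, model.metadata)
```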
register_model_version(self, name, version=None, model_source_uri=None, description=None, metadata=None, **kwargs)
Register a model version to the MLflow model registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the model. | required |
model_source_uri | Optional[str] | The source URI of the model. | None |
version | Optional[str] | The version of the model. | None |
description | Optional[str] | The description of the model version. | None |
metadata | Optional[zenml.model_registries.base_model_registry.ModelRegistryModelMetadata] | The registry metadata of the model version. | None |
**kwargs | Any | Additional keyword arguments. | {} |
Exceptions:
Type | Description |
---|---|
RuntimeError | If the registered model does not exist. |
ValueError | If no model source URI was provided. |
Returns:
Type | Description |
---|---|
RegistryModelVersion | The registered model version. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def register_model_version(
self,
name: str,
version: Optional[str] = None,
model_source_uri: Optional[str] = None,
description: Optional[str] = None,
metadata: Optional[ModelRegistryModelMetadata] = None,
**kwargs: Any,
) -> RegistryModelVersion:
"""Register a model version to the MLflow model registry.
Args:
name: The name of the model.
model_source_uri: The source URI of the model.
version: The version of the model.
description: The description of the model version.
metadata: The registry metadata of the model version.
**kwargs: Additional keyword arguments.
Raises:
RuntimeError: If the registered model does not exist.
ValueError: If no model source URI was provided.
Returns:
The registered model version.
"""
if not model_source_uri:
raise ValueError(
"Unable to register model version without model source URI."
)
# Check if the model exists, if not create it.
try:
self.get_model(name=name)
except KeyError:
logger.info(
f"No registered model with name {name} found. Creating a new "
"registered model."
)
self.register_model(
name=name,
)
try:
# Inform the user that the version is ignored.
if version:
            logger.info(
                "The MLflow model registry does not take a version as "
                "an argument. Registering a new version for the model "
                f"'{name}'; a version will be assigned automatically."
            )
metadata_dict = metadata.model_dump() if metadata else {}
# Set the run ID and link.
run_id = metadata_dict.get("mlflow_run_id", None)
run_link = metadata_dict.get("mlflow_run_link", None)
# Register the model version.
registered_model_version = self.mlflow_client.create_model_version(
name=name,
source=model_source_uri,
run_id=run_id,
run_link=run_link,
description=description,
tags=metadata_dict,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to register model version with name '{name}' and "
f"version '{version}' to the MLflow model registry."
f"Error: {e}"
)
# Return the registered model version.
return self._cast_mlflow_version_to_model_version(
registered_model_version
)
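A version-registration sketch under the same assumptions; the source URI and run ID are hypothetical placeholders, and MLflow assigns the version number itself.

```python
from zenml.client import Client
from zenml.model_registries.base_model_registry import (
    ModelRegistryModelMetadata,
)

registry = Client().active_stack.model_registry

# Hypothetical source URI and run ID; MLflow assigns the version number.
model_version = registry.register_model_version(
    name="my_model",
    model_source_uri="runs:/<run_id>/model",
    description="Trained on the latest data snapshot.",
    metadata=ModelRegistryModelMetadata(mlflow_run_id="<run_id>"),
)
print(model_version.version)
```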
update_model(self, name, description=None, metadata=None, remove_metadata=None)
Update a model in the MLflow model registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the model. | required |
description | Optional[str] | The description of the model. | None |
metadata | Optional[Dict[str, str]] | The metadata of the model. | None |
remove_metadata | Optional[List[str]] | The metadata to remove from the model. | None |
Exceptions:
Type | Description |
---|---|
RuntimeError | If mlflow fails to update the model. |
Returns:
Type | Description |
---|---|
RegisteredModel | The updated model. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def update_model(
self,
name: str,
description: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
remove_metadata: Optional[List[str]] = None,
) -> RegisteredModel:
"""Update a model in the MLflow model registry.
Args:
name: The name of the model.
description: The description of the model.
metadata: The metadata of the model.
remove_metadata: The metadata to remove from the model.
Raises:
RuntimeError: If mlflow fails to update the model.
Returns:
The updated model.
"""
# Check if model exists.
self.get_model(name=name)
# Update the registered model description.
if description:
try:
self.mlflow_client.update_registered_model(
name=name,
description=description,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to update description for the model {name} in MLflow "
f"model registry: {str(e)}",
)
# Update the registered model tags.
if metadata:
try:
for tag, value in metadata.items():
self.mlflow_client.set_registered_model_tag(
name=name,
key=tag,
value=value,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to update tags for the model {name} in MLflow model "
f"registry: {str(e)}",
)
# Remove tags from the registered model.
if remove_metadata:
try:
for tag in remove_metadata:
self.mlflow_client.delete_registered_model_tag(
name=name,
key=tag,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to remove tags for the model {name} in MLflow model "
f"registry: {str(e)}",
)
# Return the updated registered model.
return self.get_model(name)
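An update sketch under the same assumptions; the description and tag names are hypothetical.

```python
from zenml.client import Client

registry = Client().active_stack.model_registry

# Update the description, set one tag and remove another (hypothetical).
model = registry.update_model(
    name="my_model",
    description="Retrained with additional features.",
    metadata={"validated": "true"},
    remove_metadata=["deprecated"],
)
print(model.metadata)
```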
update_model_version(self, name, version, description=None, metadata=None, remove_metadata=None, stage=None)
Update a model version in the MLflow model registry.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the model. | required |
version | str | The version of the model. | required |
description | Optional[str] | The description of the model version. | None |
metadata | Optional[zenml.model_registries.base_model_registry.ModelRegistryModelMetadata] | The metadata of the model version. | None |
remove_metadata | Optional[List[str]] | The metadata to remove from the model version. | None |
stage | Optional[zenml.model_registries.base_model_registry.ModelVersionStage] | The stage of the model version. | None |
Exceptions:
Type | Description |
---|---|
RuntimeError | If mlflow fails to update the model version. |
Returns:
Type | Description |
---|---|
RegistryModelVersion | The updated model version. |
Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def update_model_version(
self,
name: str,
version: str,
description: Optional[str] = None,
metadata: Optional[ModelRegistryModelMetadata] = None,
remove_metadata: Optional[List[str]] = None,
stage: Optional[ModelVersionStage] = None,
) -> RegistryModelVersion:
"""Update a model version in the MLflow model registry.
Args:
name: The name of the model.
version: The version of the model.
description: The description of the model version.
metadata: The metadata of the model version.
remove_metadata: The metadata to remove from the model version.
stage: The stage of the model version.
Raises:
RuntimeError: If mlflow fails to update the model version.
Returns:
The updated model version.
"""
self.get_model_version(name=name, version=version)
# Update the model description.
if description:
try:
self.mlflow_client.update_model_version(
name=name,
version=version,
description=description,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to update the description of model version "
f"'{name}:{version}' in the MLflow model registry: {str(e)}"
)
# Update the model tags.
if metadata:
try:
for key, value in metadata.model_dump().items():
self.mlflow_client.set_model_version_tag(
name=name,
version=version,
key=key,
value=value,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to update the tags of model version "
f"'{name}:{version}' in the MLflow model registry: {str(e)}"
)
# Remove the model tags.
if remove_metadata:
try:
for key in remove_metadata:
self.mlflow_client.delete_model_version_tag(
name=name,
version=version,
key=key,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to remove the tags of model version "
f"'{name}:{version}' in the MLflow model registry: {str(e)}"
)
# Update the model stage.
if stage:
try:
self.mlflow_client.transition_model_version_stage(
name=name,
version=version,
stage=stage.value,
)
except MlflowException as e:
raise RuntimeError(
f"Failed to update the current stage of model version "
f"'{name}:{version}' in the MLflow model registry: {str(e)}"
)
return self.get_model_version(name, version)
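A stage-transition sketch under the same assumptions; the model name and version are hypothetical, and `ModelVersionStage` is the enum referenced in the signature above.

```python
from zenml.client import Client
from zenml.model_registries.base_model_registry import ModelVersionStage

registry = Client().active_stack.model_registry

# Promote a hypothetical model version to the production stage.
model_version = registry.update_model_version(
    name="my_model",
    version="1",
    stage=ModelVersionStage.PRODUCTION,
)
print(model_version.stage)
```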
services
special
Initialization of the MLflow Service.
mlflow_deployment
Implementation of the MLflow deployment functionality.
MLFlowDeploymentConfig (LocalDaemonServiceConfig)
MLflow model deployment configuration.
Attributes:
Name | Type | Description |
---|---|---|
model_uri | str | URI of the MLflow model to serve |
model_name | str | the name of the model |
workers | int | number of workers to use for the prediction service |
registry_model_name | Optional[str] | the name of the model in the registry |
registry_model_version | Optional[str] | the version of the model in the registry |
mlserver | bool | set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used. |
timeout | int | timeout in seconds for starting and stopping the service |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentConfig(LocalDaemonServiceConfig):
"""MLflow model deployment configuration.
Attributes:
model_uri: URI of the MLflow model to serve
model_name: the name of the model
workers: number of workers to use for the prediction service
registry_model_name: the name of the model in the registry
registry_model_version: the version of the model in the registry
mlserver: set to True to use the MLflow MLServer backend (see
https://github.com/SeldonIO/MLServer). If False, the
MLflow built-in scoring server will be used.
timeout: timeout in seconds for starting and stopping the service
"""
model_uri: str
model_name: str
registry_model_name: Optional[str] = None
registry_model_version: Optional[str] = None
registry_model_stage: Optional[str] = None
workers: int = 1
mlserver: bool = False
timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
@field_validator("mlserver")
@classmethod
def validate_mlserver_python_version(cls, mlserver: bool) -> bool:
"""Validates the Python version if mlserver is used.
Args:
mlserver: set to True if the MLflow MLServer backend is used,
else set to False and MLflow built-in scoring server will be
used.
Returns:
the validated value
Raises:
ValueError: if mlserver is set to true on Python 3.12 as it is not
yet supported.
"""
if mlserver is True and sys.version_info.minor >= 12:
raise ValueError(
"The mlserver deployment is not yet supported on Python 3.12 "
"or above."
)
return mlserver
validate_mlserver_python_version(mlserver)
classmethod
Validates the Python version if mlserver is used.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mlserver | bool | set to True if the MLflow MLServer backend is used, else set to False and MLflow built-in scoring server will be used. | required |
Returns:
Type | Description |
---|---|
bool | the validated value |
Exceptions:
Type | Description |
---|---|
ValueError | if mlserver is set to true on Python 3.12 as it is not yet supported. |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
@field_validator("mlserver")
@classmethod
def validate_mlserver_python_version(cls, mlserver: bool) -> bool:
"""Validates the Python version if mlserver is used.
Args:
mlserver: set to True if the MLflow MLServer backend is used,
else set to False and MLflow built-in scoring server will be
used.
Returns:
the validated value
Raises:
ValueError: if mlserver is set to true on Python 3.12 as it is not
yet supported.
"""
if mlserver is True and sys.version_info.minor >= 12:
raise ValueError(
"The mlserver deployment is not yet supported on Python 3.12 "
"or above."
)
return mlserver
MLFlowDeploymentEndpoint (LocalDaemonServiceEndpoint)
A service endpoint exposed by the MLflow deployment daemon.
Attributes:
Name | Type | Description |
---|---|---|
config | MLFlowDeploymentEndpointConfig | service endpoint configuration |
monitor | HTTPEndpointHealthMonitor | optional service endpoint health monitor |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpoint(LocalDaemonServiceEndpoint):
"""A service endpoint exposed by the MLflow deployment daemon.
Attributes:
config: service endpoint configuration
monitor: optional service endpoint health monitor
"""
config: MLFlowDeploymentEndpointConfig
monitor: HTTPEndpointHealthMonitor
@property
def prediction_url(self) -> Optional[str]:
"""Gets the prediction URL for the endpoint.
Returns:
the prediction URL for the endpoint
"""
uri = self.status.uri
if not uri:
return None
return os.path.join(uri, self.config.prediction_url_path)
prediction_url: Optional[str]
property
readonly
Gets the prediction URL for the endpoint.
Returns:
Type | Description |
---|---|
Optional[str] | the prediction URL for the endpoint |
MLFlowDeploymentEndpointConfig (LocalDaemonServiceEndpointConfig)
MLflow daemon service endpoint configuration.
Attributes:
Name | Type | Description |
---|---|---|
prediction_url_path | str | URI subpath for prediction requests |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpointConfig(LocalDaemonServiceEndpointConfig):
"""MLflow daemon service endpoint configuration.
Attributes:
prediction_url_path: URI subpath for prediction requests
"""
prediction_url_path: str
MLFlowDeploymentService (LocalDaemonService, BaseDeploymentService)
MLflow deployment service used to start a local prediction server for MLflow models.
Attributes:
Name | Type | Description |
---|---|---|
SERVICE_TYPE | ClassVar[zenml.services.service_type.ServiceType] | a service type descriptor with information describing the MLflow deployment service class |
config | MLFlowDeploymentConfig | service configuration |
endpoint | MLFlowDeploymentEndpoint | optional service endpoint |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentService(LocalDaemonService, BaseDeploymentService):
"""MLflow deployment service used to start a local prediction server for MLflow models.
Attributes:
SERVICE_TYPE: a service type descriptor with information describing
the MLflow deployment service class
config: service configuration
endpoint: optional service endpoint
"""
SERVICE_TYPE = ServiceType(
name="mlflow-deployment",
type="model-serving",
flavor="mlflow",
description="MLflow prediction service",
logo_url="https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png",
)
config: MLFlowDeploymentConfig
endpoint: MLFlowDeploymentEndpoint
def __init__(
self,
config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
**attrs: Any,
) -> None:
"""Initialize the MLflow deployment service.
Args:
config: service configuration
attrs: additional attributes to set on the service
"""
# ensure that the endpoint is created before the service is initialized
# TODO [ENG-700]: implement a service factory or builder for MLflow
# deployment services
if (
isinstance(config, MLFlowDeploymentConfig)
and "endpoint" not in attrs
):
if config.mlserver:
prediction_url_path = MLSERVER_PREDICTION_URL_PATH
healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
use_head_request = False
else:
prediction_url_path = MLFLOW_PREDICTION_URL_PATH
healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
use_head_request = True
endpoint = MLFlowDeploymentEndpoint(
config=MLFlowDeploymentEndpointConfig(
protocol=ServiceEndpointProtocol.HTTP,
prediction_url_path=prediction_url_path,
),
monitor=HTTPEndpointHealthMonitor(
config=HTTPEndpointHealthMonitorConfig(
healthcheck_uri_path=healthcheck_uri_path,
use_head_request=use_head_request,
)
),
)
attrs["endpoint"] = endpoint
super().__init__(config=config, **attrs)
def run(self) -> None:
"""Start the service.
Raises:
ValueError: if the active stack doesn't have an MLflow experiment
tracker
"""
logger.info(
"Starting MLflow prediction service as blocking "
"process... press CTRL+C once to stop it."
)
self.endpoint.prepare_for_start()
try:
backend_kwargs: Dict[str, Any] = {}
serve_kwargs: Dict[str, Any] = {}
mlflow_version = MLFLOW_VERSION.split(".")
# MLflow version 1.26 introduces an additional mandatory
# `timeout` argument to the `PyFuncBackend.serve` function
if int(mlflow_version[1]) >= 26 or int(mlflow_version[0]) >= 2:
serve_kwargs["timeout"] = None
            # MLflow 2.0+ requires the env_manager to be set to "local"
            # to deploy the model in the local running environment
if int(mlflow_version[0]) >= 2:
backend_kwargs["env_manager"] = "local"
backend = PyFuncBackend( # type: ignore[no-untyped-call]
config={},
no_conda=True,
workers=self.config.workers,
install_mlflow=False,
**backend_kwargs,
)
experiment_tracker = Client().active_stack.experiment_tracker
if not isinstance(experiment_tracker, MLFlowExperimentTracker):
raise ValueError(
"MLflow model deployer step requires an MLflow experiment "
"tracker. Please add an MLflow experiment tracker to your "
"stack."
)
experiment_tracker.configure_mlflow()
backend.serve( # type: ignore[no-untyped-call]
model_uri=self.config.model_uri,
port=self.endpoint.status.port,
host="localhost",
enable_mlserver=self.config.mlserver,
**serve_kwargs,
)
except KeyboardInterrupt:
logger.info(
"MLflow prediction service stopped. Resuming normal execution."
)
@property
def prediction_url(self) -> Optional[str]:
"""Get the URI where the prediction service is answering requests.
Returns:
The URI where the prediction service can be contacted to process
HTTP/REST inference requests, or None, if the service isn't running.
"""
if not self.is_running:
return None
return self.endpoint.prediction_url
def predict(
self, request: Union["NDArray[Any]", pd.DataFrame]
) -> "NDArray[Any]":
"""Make a prediction using the service.
Args:
request: a Numpy Array or Pandas DataFrame representing the request
Returns:
A numpy array representing the prediction returned by the service.
Raises:
Exception: if the service is not running
ValueError: if the prediction endpoint is unknown.
"""
if not self.is_running:
raise Exception(
"MLflow prediction service is not running. "
"Please start the service before making predictions."
)
if self.endpoint.prediction_url is not None:
if type(request) is pd.DataFrame:
response = requests.post( # nosec
self.endpoint.prediction_url,
json={"instances": request.to_dict("records")},
)
else:
response = requests.post( # nosec
self.endpoint.prediction_url,
json={"instances": request.tolist()},
)
else:
raise ValueError("No endpoint known for prediction.")
response.raise_for_status()
if int(MLFLOW_VERSION.split(".")[0]) <= 1:
return np.array(response.json())
else:
            # MLflow 2.0+ returns a dictionary with the predictions
            # under the "predictions" key
return np.array(response.json()["predictions"])
prediction_url: Optional[str]
property
readonly
Get the URI where the prediction service is answering requests.
Returns:
Type | Description |
---|---|
Optional[str] | The URI where the prediction service can be contacted to process HTTP/REST inference requests, or None, if the service isn't running. |
__init__(self, config, **attrs)
special
Initialize the MLflow deployment service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | Union[zenml.integrations.mlflow.services.mlflow_deployment.MLFlowDeploymentConfig, Dict[str, Any]] | service configuration | required |
attrs | Any | additional attributes to set on the service | {} |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def __init__(
    self,
    config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Initialize the MLflow deployment service.

    Args:
        config: service configuration
        attrs: additional attributes to set on the service
    """
    # ensure that the endpoint is created before the service is initialized
    # TODO [ENG-700]: implement a service factory or builder for MLflow
    #   deployment services
    if (
        isinstance(config, MLFlowDeploymentConfig)
        and "endpoint" not in attrs
    ):
        if config.mlserver:
            prediction_url_path = MLSERVER_PREDICTION_URL_PATH
            healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
            use_head_request = False
        else:
            prediction_url_path = MLFLOW_PREDICTION_URL_PATH
            healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
            use_head_request = True
        endpoint = MLFlowDeploymentEndpoint(
            config=MLFlowDeploymentEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                prediction_url_path=prediction_url_path,
            ),
            monitor=HTTPEndpointHealthMonitor(
                config=HTTPEndpointHealthMonitorConfig(
                    healthcheck_uri_path=healthcheck_uri_path,
                    use_head_request=use_head_request,
                )
            ),
        )
        attrs["endpoint"] = endpoint
    super().__init__(config=config, **attrs)
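As a quick illustration of the endpoint wiring above: when an MLFlowDeploymentConfig is passed and no "endpoint" attribute is supplied, the constructor builds the HTTP endpoint and health monitor itself. A minimal sketch, using only config fields that appear in the source on this page (model_uri, workers, mlserver); the field values are assumptions and a real config may require additional fields:

from zenml.integrations.mlflow.services.mlflow_deployment import (
    MLFlowDeploymentConfig,
    MLFlowDeploymentService,
)

config = MLFlowDeploymentConfig(
    model_uri="models:/my-model/1",  # hypothetical model registry URI
    workers=1,
    mlserver=False,  # plain MLflow scoring server, not MLServer
)

# No "endpoint" in attrs, so __init__ creates an MLFlowDeploymentEndpoint
# with an HTTP health monitor (HEAD requests, since mlserver is False).
service = MLFlowDeploymentService(config=config)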
predict(self, request)
Make a prediction using the service.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
request | Union[NDArray[Any], pandas.DataFrame] | a Numpy Array or Pandas DataFrame representing the request | required |
Returns:
Type | Description |
---|---|
NDArray[Any] | A numpy array representing the prediction returned by the service. |
Exceptions:
Type | Description |
---|---|
Exception | if the service is not running |
ValueError | if the prediction endpoint is unknown. |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def predict(
    self, request: Union["NDArray[Any]", pd.DataFrame]
) -> "NDArray[Any]":
    """Make a prediction using the service.

    Args:
        request: a Numpy Array or Pandas DataFrame representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not running
        ValueError: if the prediction endpoint is unknown.
    """
    if not self.is_running:
        raise Exception(
            "MLflow prediction service is not running. "
            "Please start the service before making predictions."
        )
    if self.endpoint.prediction_url is not None:
        if type(request) is pd.DataFrame:
            response = requests.post(  # nosec
                self.endpoint.prediction_url,
                json={"instances": request.to_dict("records")},
            )
        else:
            response = requests.post(  # nosec
                self.endpoint.prediction_url,
                json={"instances": request.tolist()},
            )
    else:
        raise ValueError("No endpoint known for prediction.")
    response.raise_for_status()
    if int(MLFLOW_VERSION.split(".")[0]) <= 1:
        return np.array(response.json())
    else:
        # MLflow 2.0+ returns a dictionary with the predictions
        # under the "predictions" key
        return np.array(response.json()["predictions"])
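A usage sketch for predict(), assuming service is an already-running MLFlowDeploymentService (names and values are illustrative). Both input types end up in the same "instances" payload; a DataFrame is serialized as a list of records, an array as a nested list:

import numpy as np
import pandas as pd

features = np.array([[5.1, 3.5, 1.4, 0.2]])
print(service.predict(features))  # sent as {"instances": [[5.1, ...]]}

df = pd.DataFrame(features, columns=["a", "b", "c", "d"])
print(service.predict(df))  # sent as {"instances": [{"a": 5.1, ...}]}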
run(self)
Start the service.
Exceptions:
Type | Description |
---|---|
ValueError | if the active stack doesn't have an MLflow experiment tracker |
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def run(self) -> None:
    """Start the service.

    Raises:
        ValueError: if the active stack doesn't have an MLflow experiment
            tracker
    """
    logger.info(
        "Starting MLflow prediction service as blocking "
        "process... press CTRL+C once to stop it."
    )
    self.endpoint.prepare_for_start()
    try:
        backend_kwargs: Dict[str, Any] = {}
        serve_kwargs: Dict[str, Any] = {}
        mlflow_version = MLFLOW_VERSION.split(".")
        # MLflow version 1.26 introduces an additional mandatory
        # `timeout` argument to the `PyFuncBackend.serve` function
        if int(mlflow_version[1]) >= 26 or int(mlflow_version[0]) >= 2:
            serve_kwargs["timeout"] = None
        # MLflow 2.0+ requires the env_manager to be set to "local"
        # to deploy the model in the local running environment
        if int(mlflow_version[0]) >= 2:
            backend_kwargs["env_manager"] = "local"
        backend = PyFuncBackend(  # type: ignore[no-untyped-call]
            config={},
            no_conda=True,
            workers=self.config.workers,
            install_mlflow=False,
            **backend_kwargs,
        )
        experiment_tracker = Client().active_stack.experiment_tracker
        if not isinstance(experiment_tracker, MLFlowExperimentTracker):
            raise ValueError(
                "MLflow model deployer step requires an MLflow experiment "
                "tracker. Please add an MLflow experiment tracker to your "
                "stack."
            )
        experiment_tracker.configure_mlflow()
        backend.serve(  # type: ignore[no-untyped-call]
            model_uri=self.config.model_uri,
            port=self.endpoint.status.port,
            host="localhost",
            enable_mlserver=self.config.mlserver,
            **serve_kwargs,
        )
    except KeyboardInterrupt:
        logger.info(
            "MLflow prediction service stopped. Resuming normal execution."
        )
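Note that run() serves the model in the foreground and only returns after CTRL+C; the ValueError for a missing MLflow experiment tracker is raised before serving starts. A minimal sketch, assuming service is a configured MLFlowDeploymentService (in practice the MLflow model deployer manages the service lifecycle rather than calling run() directly):

try:
    # Blocks the current process until interrupted with CTRL+C.
    service.run()
except ValueError as e:
    # Raised when the active stack has no MLflow experiment tracker.
    print(f"Stack is misconfigured: {e}")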
steps
special
Initialization of the MLflow standard interface steps.
mlflow_deployer
Implementation of the MLflow model deployer pipeline step.
mlflow_registry
Implementation of the MLflow model registration pipeline step.