Skip to content

Mlflow

zenml.integrations.mlflow special

Initialization for the ZenML MLflow integration.

The MLflow integration currently enables you to use MLflow tracking as a convenient way to visualize your experiment runs within the MLflow UI.

MlflowIntegration (Integration)

Definition of MLflow integration for ZenML.

Source code in zenml/integrations/mlflow/__init__.py
class MlflowIntegration(Integration):
    """ZenML integration definition for MLflow.

    Declares the integration name, the pip requirements that are installed
    with it and the stack component flavors it contributes.
    """

    NAME = MLFLOW

    REQUIREMENTS = [
        "mlflow>=2.1.1,<=2.14.1",
        "mlserver>=1.3.3",
        "mlserver-mlflow>=1.3.3",
        # TODO: drop this pin once the rapidjson problem is fixed upstream
        "python-rapidjson<1.15",
        # `pip install zenml` installs all of our required dependencies.
        # Following it up with `zenml integration install mlflow` would,
        # however, downgrade pydantic to v1 even though mlflow has no
        # problems with v2. Pinning pydantic here prevents that downgrade.
        "pydantic>=2.7.0,<2.8.0",
    ]

    @classmethod
    def activate(cls) -> None:
        """Activate the MLflow integration.

        Imports the MLflow services module of the integration; the imported
        name is not used any further in this method.
        """
        from zenml.integrations.mlflow import services  # noqa

    @classmethod
    def flavors(cls) -> List[Type[Flavor]]:
        """Declare the stack component flavors for the MLflow integration.

        The flavor classes are imported lazily inside the method.

        Returns:
            List of stack component flavors for this integration.
        """
        from zenml.integrations.mlflow.flavors import (
            MLFlowExperimentTrackerFlavor,
            MLFlowModelDeployerFlavor,
            MLFlowModelRegistryFlavor,
        )

        integration_flavors: List[Type[Flavor]] = [
            MLFlowModelDeployerFlavor,
            MLFlowExperimentTrackerFlavor,
            MLFlowModelRegistryFlavor,
        ]
        return integration_flavors

activate() classmethod

Activate the MLflow integration.

Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def activate(cls) -> None:
    """Activate the MLflow integration.

    Imports the MLflow services module of the integration; the imported
    name is not used any further in this method.
    """
    from zenml.integrations.mlflow import services  # noqa

flavors() classmethod

Declare the stack component flavors for the MLflow integration.

Returns:

Type Description
List[Type[zenml.stack.flavor.Flavor]]

List of stack component flavors for this integration.

Source code in zenml/integrations/mlflow/__init__.py
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the MLflow integration.

    The flavor classes are imported lazily inside the method.

    Returns:
        List of stack component flavors for this integration.
    """
    from zenml.integrations.mlflow.flavors import (
        MLFlowExperimentTrackerFlavor,
        MLFlowModelDeployerFlavor,
        MLFlowModelRegistryFlavor,
    )

    integration_flavors: List[Type[Flavor]] = [
        MLFlowModelDeployerFlavor,
        MLFlowExperimentTrackerFlavor,
        MLFlowModelRegistryFlavor,
    ]
    return integration_flavors

experiment_trackers special

Initialization of the MLflow experiment tracker.

mlflow_experiment_tracker

Implementation of the MLflow experiment tracker for ZenML.

MLFlowExperimentTracker (BaseExperimentTracker)

Track experiments using MLflow.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
class MLFlowExperimentTracker(BaseExperimentTracker):
    """Track experiments using MLflow.

    The tracker points MLflow at the configured tracking URI (or at a
    `mlruns` directory inside the active artifact store when none is
    configured), exports the credentials through environment variables and
    starts/stops MLflow runs around each step.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize the experiment tracker and validate the tracking uri.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        super().__init__(*args, **kwargs)
        self._ensure_valid_tracking_uri()

    def _ensure_valid_tracking_uri(self) -> None:
        """Ensures that the tracking uri is a valid mlflow tracking uri.

        Nothing is checked when no tracking uri is configured.

        Raises:
            ValueError: If the tracking uri is not valid.
        """
        tracking_uri = self.config.tracking_uri
        if tracking_uri:
            valid_schemes = DATABASE_ENGINES + ["http", "https", "file"]
            # The uri is accepted if it starts with any valid scheme or if
            # it is recognized as a Databricks tracking uri.
            if not any(
                tracking_uri.startswith(scheme) for scheme in valid_schemes
            ) and not is_databricks_tracking_uri(tracking_uri):
                raise ValueError(
                    f"MLflow tracking uri does not start with one of the valid "
                    f"schemes {valid_schemes} or its value is not set to "
                    f"'databricks'. See "
                    f"https://www.mlflow.org/docs/latest/tracking.html#where-runs-are-recorded "
                    f"for more information."
                )

    @property
    def config(self) -> MLFlowExperimentTrackerConfig:
        """Returns the `MLFlowExperimentTrackerConfig` config.

        Returns:
            The configuration.
        """
        return cast(MLFlowExperimentTrackerConfig, self._config)

    @property
    def local_path(self) -> Optional[str]:
        """Path to the local directory where the MLflow artifacts are stored.

        Returns:
            None if configured with a remote tracking URI, otherwise the
            path to the local MLflow artifact store directory.
        """
        tracking_uri = self.get_tracking_uri()
        if is_remote_mlflow_tracking_uri(tracking_uri):
            return None
        else:
            # NOTE(review): every non-remote uri is assumed to use the
            # `file:` scheme here. `_ensure_valid_tracking_uri` also accepts
            # database engine schemes - confirm whether those can reach
            # this branch.
            assert tracking_uri.startswith("file:")
            # Strip the 5-character "file:" prefix to obtain the path.
            return tracking_uri[5:]

    @property
    def validator(self) -> Optional["StackValidator"]:
        """Checks the stack has a `LocalArtifactStore` if no tracking uri was specified.

        Returns:
            An optional `StackValidator`.
        """
        if self.config.tracking_uri:
            # user specified a tracking uri, do nothing
            return None
        else:
            # try to fall back to a tracking uri inside the zenml artifact
            # store. this only works in case of a local artifact store, so we
            # make sure to prevent stack with other artifact stores for now
            return StackValidator(
                custom_validation_function=lambda stack: (
                    isinstance(stack.artifact_store, LocalArtifactStore),
                    "MLflow experiment tracker without a specified tracking "
                    "uri only works with a local artifact store.",
                )
            )

    @property
    def settings_class(self) -> Optional[Type["BaseSettings"]]:
        """Settings class for the MLflow experiment tracker.

        Returns:
            The settings class.
        """
        return MLFlowExperimentTrackerSettings

    @staticmethod
    def _local_mlflow_backend() -> str:
        """Gets the local MLflow backend inside the ZenML artifact repository directory.

        The `mlruns` directory below the active stack's artifact store path
        is created if it does not exist yet.

        Returns:
            The MLflow tracking URI for the local MLflow backend.
        """
        client = Client()
        artifact_store = client.active_stack.artifact_store
        local_mlflow_tracking_uri = os.path.join(artifact_store.path, "mlruns")
        # `exist_ok=True` avoids the race between checking for the directory
        # and creating it when several processes start at the same time.
        os.makedirs(local_mlflow_tracking_uri, exist_ok=True)
        return "file:" + local_mlflow_tracking_uri

    def get_tracking_uri(self, as_plain_text: bool = True) -> str:
        """Returns the configured tracking URI or a local fallback.

        Args:
            as_plain_text: Whether to return the tracking URI as plain text.
                When True the value is read from the config attribute,
                otherwise from the config's `model_dump()` output.

        Returns:
            The tracking URI.
        """
        if as_plain_text:
            tracking_uri = self.config.tracking_uri
        else:
            tracking_uri = self.config.model_dump()["tracking_uri"]
        # Fall back to the local backend when no uri is configured.
        return tracking_uri or self._local_mlflow_backend()

    def prepare_step_run(self, info: "StepRunInfo") -> None:
        """Sets the MLflow tracking uri and credentials.

        Also activates the experiment and starts (or resumes) the MLflow run
        named after the pipeline run; a nested run is started on top when
        the settings request it.

        Args:
            info: Info about the step that will be executed.
        """
        self.configure_mlflow()
        settings = cast(
            MLFlowExperimentTrackerSettings,
            self.get_settings(info),
        )

        # The pipeline name is used when no experiment name is configured.
        experiment_name = settings.experiment_name or info.pipeline.name
        experiment = self._set_active_experiment(experiment_name)
        # `run_id` is None when no matching ZenML run exists yet.
        run_id = self.get_run_id(
            experiment_name=experiment_name, run_name=info.run_name
        )

        # Copy so that the tags stored on the settings are not mutated.
        tags = settings.tags.copy()
        tags.update(self._get_internal_tags())

        mlflow.start_run(
            run_id=run_id,
            run_name=info.run_name,
            experiment_id=experiment.experiment_id,
            tags=tags,
        )

        if settings.nested:
            mlflow.start_run(
                run_name=info.pipeline_step_name, nested=True, tags=tags
            )

    def get_step_run_metadata(
        self, info: "StepRunInfo"
    ) -> Dict[str, "MetadataType"]:
        """Get component- and step-specific metadata after a step ran.

        The run and experiment ids are only included while an MLflow run is
        active.

        Args:
            info: Info about the step that was executed.

        Returns:
            A dictionary of metadata.
        """
        metadata: Dict[str, "MetadataType"] = {
            METADATA_EXPERIMENT_TRACKER_URL: Uri(
                self.get_tracking_uri(as_plain_text=False)
            ),
        }
        # `mlflow.active_run()` returns None when no run is active; look it
        # up once and skip the run details instead of failing on `.info`.
        active_run = mlflow.active_run()
        if active_run is not None:
            metadata["mlflow_run_id"] = active_run.info.run_id
            metadata["mlflow_experiment_id"] = active_run.info.experiment_id
        return metadata

    def disable_autologging(self) -> None:
        """Disables MLflow autologging for all supported frameworks.

        Frameworks for which disabling failed are collected and reported in
        a single warning.
        """
        frameworks = [
            "tensorflow",
            "gluon",
            "xgboost",
            "lightgbm",
            "statsmodels",
            "spark",
            "sklearn",
            "fastai",
            "pytorch",
        ]

        failed_frameworks = []

        for framework in frameworks:
            try:
                # Correctly prefix the module name with 'mlflow.'
                module_name = f"mlflow.{framework}"
                # Dynamically import the module corresponding to the framework
                module = importlib.import_module(module_name)
                # Call the autolog function with disable=True
                module.autolog(disable=True)
            except ImportError as e:
                # only log on mlflow relevant errors; `msg` is None when the
                # error was raised without a message
                if "mlflow" in (e.msg or "").lower():
                    failed_frameworks.append(framework)
            except Exception:
                failed_frameworks.append(framework)

        if failed_frameworks:
            logger.warning(
                f"Failed to disable MLflow autologging for the following frameworks: "
                f"{failed_frameworks}."
            )

    def cleanup_step_run(
        self,
        info: "StepRunInfo",
        step_failed: bool,
    ) -> None:
        """Stops active MLflow runs and resets the MLflow tracking uri.

        Args:
            info: Info about the step that was executed.
            step_failed: Whether the step failed or not.
        """
        status = "FAILED" if step_failed else "FINISHED"
        self.disable_autologging()
        mlflow_utils.stop_zenml_mlflow_runs(status)
        mlflow.set_tracking_uri("")

    def configure_mlflow(self) -> None:
        """Configures the MLflow tracking URI and any additional credentials.

        The tracking and registry uri are both set to the tracking uri. The
        credentials that are configured are exported as environment
        variables, using the Databricks variables for a Databricks tracking
        uri and the MLflow variables otherwise.
        """
        tracking_uri = self.get_tracking_uri()
        mlflow.set_tracking_uri(tracking_uri)
        mlflow.set_registry_uri(tracking_uri)

        if is_databricks_tracking_uri(tracking_uri):
            if self.config.databricks_host:
                os.environ[DATABRICKS_HOST] = self.config.databricks_host
            if self.config.tracking_username:
                os.environ[DATABRICKS_USERNAME] = self.config.tracking_username
            if self.config.tracking_password:
                os.environ[DATABRICKS_PASSWORD] = self.config.tracking_password
            if self.config.tracking_token:
                os.environ[DATABRICKS_TOKEN] = self.config.tracking_token
        else:
            os.environ[MLFLOW_TRACKING_URI] = tracking_uri
            if self.config.tracking_username:
                os.environ[MLFLOW_TRACKING_USERNAME] = (
                    self.config.tracking_username
                )
            if self.config.tracking_password:
                os.environ[MLFLOW_TRACKING_PASSWORD] = (
                    self.config.tracking_password
                )
            if self.config.tracking_token:
                os.environ[MLFLOW_TRACKING_TOKEN] = self.config.tracking_token

        # Always exported, for both kinds of tracking uri.
        os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
            "true" if self.config.tracking_insecure_tls else "false"
        )

    def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]:
        """Gets the id of a run with the given name and experiment.

        Only the first search result is considered, and only if it is a
        run created by ZenML.

        Args:
            experiment_name: Name of the experiment in which to search for the
                run.
            run_name: Name of the run to search.

        Returns:
            The id of the run if it exists.
        """
        self.configure_mlflow()
        experiment_name = self._adjust_experiment_name(experiment_name)

        runs = mlflow.search_runs(
            experiment_names=[experiment_name],
            filter_string=f'tags.mlflow.runName = "{run_name}"',
            # NOTE(review): 3 presumably selects all runs (active and
            # deleted) - confirm against MLflow's `ViewType`.
            run_view_type=3,
            output_format="list",
        )
        if not runs:
            return None

        run: Run = runs[0]
        if mlflow_utils.is_zenml_run(run):
            return cast(str, run.info.run_id)
        else:
            return None

    def _set_active_experiment(self, experiment_name: str) -> Experiment:
        """Sets the active MLflow experiment.

        If no experiment with this name exists, it is created and then
        activated.

        Args:
            experiment_name: Name of the experiment to activate.

        Raises:
            RuntimeError: If the experiment creation or activation failed.

        Returns:
            The experiment.
        """
        experiment_name = self._adjust_experiment_name(experiment_name)

        mlflow.set_experiment(experiment_name=experiment_name)
        # Look the experiment up again to verify that activation worked.
        experiment = mlflow.get_experiment_by_name(experiment_name)
        if not experiment:
            raise RuntimeError("Failed to set active mlflow experiment.")
        return experiment

    def _adjust_experiment_name(self, experiment_name: str) -> str:
        """Prepends a slash to the experiment name if using Databricks.

        Databricks requires the experiment name to be an absolute path within
        the Databricks workspace.

        Args:
            experiment_name: The experiment name.

        Returns:
            The potentially adjusted experiment name.
        """
        tracking_uri = self.get_tracking_uri()

        if (
            tracking_uri
            and is_databricks_tracking_uri(tracking_uri)
            and not experiment_name.startswith("/")
        ):
            return f"/{experiment_name}"
        else:
            return experiment_name

    @staticmethod
    def _get_internal_tags() -> Dict[str, Any]:
        """Gets ZenML internal tags for MLflow runs.

        Returns:
            Internal tags: the ZenML tag key mapped to the ZenML version.
        """
        return {mlflow_utils.ZENML_TAG_KEY: zenml.__version__}
config: MLFlowExperimentTrackerConfig property readonly

Returns the MLFlowExperimentTrackerConfig config.

Returns:

Type Description
MLFlowExperimentTrackerConfig

The configuration.

local_path: Optional[str] property readonly

Path to the local directory where the MLflow artifacts are stored.

Returns:

Type Description
Optional[str]

None if configured with a remote tracking URI, otherwise the path to the local MLflow artifact store directory.

settings_class: Optional[Type[BaseSettings]] property readonly

Settings class for the MLflow experiment tracker.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[StackValidator] property readonly

Checks the stack has a LocalArtifactStore if no tracking uri was specified.

Returns:

Type Description
Optional[StackValidator]

An optional StackValidator.

__init__(self, *args, **kwargs) special

Initialize the experiment tracker and validate the tracking uri.

Parameters:

Name Type Description Default
*args Any

Variable length argument list.

()
**kwargs Any

Arbitrary keyword arguments.

{}
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Set up the experiment tracker and check its tracking uri.

    All arguments are forwarded unchanged to the parent initializer;
    afterwards the configured tracking uri is validated.

    Args:
        *args: Positional arguments for the parent initializer.
        **kwargs: Keyword arguments for the parent initializer.
    """
    super().__init__(*args, **kwargs)
    self._ensure_valid_tracking_uri()
cleanup_step_run(self, info, step_failed)

Stops active MLflow runs and resets the MLflow tracking uri.

Parameters:

Name Type Description Default
info StepRunInfo

Info about the step that was executed.

required
step_failed bool

Whether the step failed or not.

required
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def cleanup_step_run(
    self,
    info: "StepRunInfo",
    step_failed: bool,
) -> None:
    """Finish the MLflow runs of a step and clear the tracking uri.

    Autologging is switched off, the ZenML MLflow runs are stopped with a
    status reflecting the step outcome and the MLflow tracking uri is reset
    to an empty string.

    Args:
        info: Info about the step that was executed.
        step_failed: Whether the step failed or not.
    """
    self.disable_autologging()

    if step_failed:
        final_status = "FAILED"
    else:
        final_status = "FINISHED"

    mlflow_utils.stop_zenml_mlflow_runs(final_status)
    mlflow.set_tracking_uri("")
configure_mlflow(self)

Configures the MLflow tracking URI and any additional credentials.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def configure_mlflow(self) -> None:
    """Point MLflow at the tracking uri and export the credentials.

    The tracking and registry uri are both set to the tracking uri. Every
    credential that is configured is exported as an environment variable,
    using the Databricks variables for a Databricks tracking uri and the
    MLflow variables otherwise. The insecure-TLS flag is always exported.
    """
    uri = self.get_tracking_uri()
    mlflow.set_tracking_uri(uri)
    mlflow.set_registry_uri(uri)

    # Collect (environment variable, configured value) pairs for the
    # backend in use; unset values are skipped below.
    if is_databricks_tracking_uri(uri):
        candidates = [
            (DATABRICKS_HOST, self.config.databricks_host),
            (DATABRICKS_USERNAME, self.config.tracking_username),
            (DATABRICKS_PASSWORD, self.config.tracking_password),
            (DATABRICKS_TOKEN, self.config.tracking_token),
        ]
    else:
        os.environ[MLFLOW_TRACKING_URI] = uri
        candidates = [
            (MLFLOW_TRACKING_USERNAME, self.config.tracking_username),
            (MLFLOW_TRACKING_PASSWORD, self.config.tracking_password),
            (MLFLOW_TRACKING_TOKEN, self.config.tracking_token),
        ]

    for variable, value in candidates:
        if value:
            os.environ[variable] = value

    if self.config.tracking_insecure_tls:
        insecure_tls = "true"
    else:
        insecure_tls = "false"
    os.environ[MLFLOW_TRACKING_INSECURE_TLS] = insecure_tls
disable_autologging(self)

Disables MLflow autologging for all supported frameworks.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def disable_autologging(self) -> None:
    """Disables MLflow autologging for all supported frameworks.

    For each framework the matching `mlflow.<framework>` module is imported
    and its `autolog(disable=True)` is called. Frameworks for which this
    failed are collected and reported in a single warning.
    """
    frameworks = [
        "tensorflow",
        "gluon",
        "xgboost",
        "lightgbm",
        "statsmodels",
        "spark",
        "sklearn",
        "fastai",
        "pytorch",
    ]

    failed_frameworks = []

    for framework in frameworks:
        try:
            # Correctly prefix the module name with 'mlflow.'
            module_name = f"mlflow.{framework}"
            # Dynamically import the module corresponding to the framework
            module = importlib.import_module(module_name)
            # Call the autolog function with disable=True
            module.autolog(disable=True)
        except ImportError as e:
            # only log on mlflow relevant errors; `msg` is None when the
            # error was raised without a message, so guard before lowering
            if "mlflow" in (e.msg or "").lower():
                failed_frameworks.append(framework)
        except Exception:
            failed_frameworks.append(framework)

    if failed_frameworks:
        logger.warning(
            f"Failed to disable MLflow autologging for the following frameworks: "
            f"{failed_frameworks}."
        )
get_run_id(self, experiment_name, run_name)

Gets the id of a run with the given name and experiment.

Parameters:

Name Type Description Default
experiment_name str

Name of the experiment in which to search for the run.

required
run_name str

Name of the run to search.

required

Returns:

Type Description
Optional[str]

The id of the run if it exists.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]:
    """Look up the id of the ZenML run with the given name.

    MLflow is configured first, then the (possibly adjusted) experiment is
    searched for runs whose `mlflow.runName` tag equals the run name. Only
    the first result is considered, and only if it is a ZenML run.

    Args:
        experiment_name: Name of the experiment in which to search for the
            run.
        run_name: Name of the run to search.

    Returns:
        The id of the run if it exists, otherwise None.
    """
    self.configure_mlflow()
    adjusted_name = self._adjust_experiment_name(experiment_name)

    matches = mlflow.search_runs(
        experiment_names=[adjusted_name],
        filter_string=f'tags.mlflow.runName = "{run_name}"',
        run_view_type=3,
        output_format="list",
    )

    if matches:
        first_match: Run = matches[0]
        if mlflow_utils.is_zenml_run(first_match):
            return cast(str, first_match.info.run_id)

    return None
get_step_run_metadata(self, info)

Get component- and step-specific metadata after a step ran.

Parameters:

Name Type Description Default
info StepRunInfo

Info about the step that was executed.

required

Returns:

Type Description
Dict[str, MetadataType]

A dictionary of metadata.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_step_run_metadata(
    self, info: "StepRunInfo"
) -> Dict[str, "MetadataType"]:
    """Get component- and step-specific metadata after a step ran.

    The tracker URL is always reported. The MLflow run and experiment ids
    are only included while an MLflow run is active.

    Args:
        info: Info about the step that was executed.

    Returns:
        A dictionary of metadata.
    """
    metadata: Dict[str, "MetadataType"] = {
        METADATA_EXPERIMENT_TRACKER_URL: Uri(
            self.get_tracking_uri(as_plain_text=False)
        ),
    }
    # `mlflow.active_run()` returns None when no run is active; look it up
    # once and skip the run details instead of failing on `.info`.
    active_run = mlflow.active_run()
    if active_run is not None:
        metadata["mlflow_run_id"] = active_run.info.run_id
        metadata["mlflow_experiment_id"] = active_run.info.experiment_id
    return metadata
get_tracking_uri(self, as_plain_text=True)

Returns the configured tracking URI or a local fallback.

Parameters:

Name Type Description Default
as_plain_text bool

Whether to return the tracking URI as plain text.

True

Returns:

Type Description
str

The tracking URI.

Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def get_tracking_uri(self, as_plain_text: bool = True) -> str:
    """Return the tracking URI to use, falling back to the local backend.

    Args:
        as_plain_text: Whether to return the tracking URI as plain text.
            When True the value is read from the config attribute,
            otherwise from the config's `model_dump()` output.

    Returns:
        The configured tracking URI, or the local MLflow backend URI when
        no tracking URI is configured.
    """
    configured_uri = (
        self.config.tracking_uri
        if as_plain_text
        else self.config.model_dump()["tracking_uri"]
    )
    if configured_uri:
        return configured_uri
    return self._local_mlflow_backend()
prepare_step_run(self, info)

Sets the MLflow tracking uri and credentials.

Parameters:

Name Type Description Default
info StepRunInfo

Info about the step that will be executed.

required
Source code in zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
def prepare_step_run(self, info: "StepRunInfo") -> None:
    """Configure MLflow and open the run(s) for the upcoming step.

    The experiment named in the settings (or, failing that, after the
    pipeline) is activated and the MLflow run named after the pipeline run
    is started; an existing ZenML run with that name is resumed through its
    id. When the settings ask for nested runs, a nested run named after the
    step is started on top.

    Args:
        info: Info about the step that will be executed.
    """
    self.configure_mlflow()
    step_settings = cast(
        MLFlowExperimentTrackerSettings,
        self.get_settings(info),
    )

    name = step_settings.experiment_name or info.pipeline.name
    active_experiment = self._set_active_experiment(name)
    existing_run_id = self.get_run_id(
        experiment_name=name, run_name=info.run_name
    )

    # Work on a copy so the tags stored on the settings stay untouched.
    run_tags = step_settings.tags.copy()
    run_tags.update(self._get_internal_tags())

    mlflow.start_run(
        run_id=existing_run_id,
        run_name=info.run_name,
        experiment_id=active_experiment.experiment_id,
        tags=run_tags,
    )

    if step_settings.nested:
        mlflow.start_run(
            run_name=info.pipeline_step_name, nested=True, tags=run_tags
        )

flavors special

MLFlow integration flavors.

mlflow_experiment_tracker_flavor

MLflow experiment tracker flavor.

MLFlowExperimentTrackerConfig (BaseExperimentTrackerConfig, MLFlowExperimentTrackerSettings)

Config for the MLflow experiment tracker.

Attributes:

Name Type Description
tracking_uri Optional[str]

The uri of the mlflow tracking server. If no uri is set, your stack must contain a LocalArtifactStore and ZenML will point MLflow to a subdirectory of your artifact store instead.

tracking_username Optional[str]

Username for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_password Optional[str]

Password for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_token Optional[str]

Token for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_insecure_tls bool

Skips verification of TLS connection to the MLflow tracking server if set to True.

databricks_host Optional[str]

The host of the Databricks workspace with the MLflow managed server to connect to. This is only required if tracking_uri value is set to "databricks".

Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerConfig(
    BaseExperimentTrackerConfig, MLFlowExperimentTrackerSettings
):
    """Configuration of the MLflow experiment tracker.

    Attributes:
        tracking_uri: The uri of the mlflow tracking server. If no uri is set,
            your stack must contain a `LocalArtifactStore` and ZenML will
            point MLflow to a subdirectory of your artifact store instead.
        tracking_username: Username for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_password: Password for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_token: Token for authenticating with the MLflow
            tracking server. When a remote tracking uri is specified,
            either `tracking_token` or `tracking_username` and
            `tracking_password` must be specified.
        tracking_insecure_tls: Skips verification of TLS connection to the
            MLflow tracking server if set to `True`.
        databricks_host: The host of the Databricks workspace with the MLflow
            managed server to connect to. This is only required if
            `tracking_uri` value is set to `"databricks"`.
    """

    tracking_uri: Optional[str] = None
    tracking_username: Optional[str] = SecretField(default=None)
    tracking_password: Optional[str] = SecretField(default=None)
    tracking_token: Optional[str] = SecretField(default=None)
    tracking_insecure_tls: bool = False
    databricks_host: Optional[str] = None

    @model_validator(mode="after")
    def _ensure_authentication_if_necessary(
        self,
    ) -> "MLFlowExperimentTrackerConfig":
        """Check that the settings required by the tracking uri are present.

        Nothing is checked without a tracking uri. A Databricks tracking uri
        requires `databricks_host`; a remote tracking uri requires either
        username and password or a token.

        Returns:
            The validated config.

        Raises:
            ValueError: If the Databricks host is missing for a Databricks
                tracking uri, or if neither credentials nor a token are
                provided for a remote tracking uri.
        """
        uri = self.tracking_uri
        if not uri:
            return self

        # A Databricks tracking uri cannot be used without the host of the
        # Databricks workspace.
        if is_databricks_tracking_uri(uri) and not self.databricks_host:
            raise ValueError(
                "MLflow experiment tracking with a Databricks MLflow "
                "managed tracking server requires the "
                "`databricks_host` to be set in your stack component. "
                "To update your component, run "
                "`zenml experiment-tracker update "
                "<NAME> --databricks_host=DATABRICKS_HOST` "
                "and specify the hostname of your Databricks workspace."
            )

        if is_remote_mlflow_tracking_uri(uri):
            # Authenticating against the remote backend needs either
            # username + password or a token.
            has_basic_auth = self.tracking_username and self.tracking_password

            if not has_basic_auth and not self.tracking_token:
                raise ValueError(
                    f"MLflow experiment tracking with a remote backend "
                    f"{self.tracking_uri} is only possible when specifying "
                    f"either username and password or an authentication "
                    f"token in your stack component. To update your "
                    f"component, run the following command: "
                    f"`zenml experiment-tracker update "
                    f"<NAME> --tracking_username=MY_USERNAME "
                    f"--tracking_password=MY_PASSWORD "
                    f"--tracking_token=MY_TOKEN` and specify either your "
                    f"username and password or token."
                )

        return self

    @property
    def is_local(self) -> bool:
        """Tell whether this stack component is running locally.

        The component counts as local when no tracking uri is set or when
        the tracking uri is not a remote MLflow tracking uri.

        Returns:
            True if this config is for a local component, False otherwise.
        """
        uses_remote_backend = bool(
            self.tracking_uri
            and is_remote_mlflow_tracking_uri(self.tracking_uri)
        )
        return not uses_remote_backend
is_local: bool property readonly

Checks if this stack component is running locally.

Returns:

Type Description
bool

True if this config is for a local component, False otherwise.

MLFlowExperimentTrackerFlavor (BaseExperimentTrackerFlavor)

Class for the MLFlowExperimentTrackerFlavor.

Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerFlavor(BaseExperimentTrackerFlavor):
    """Flavor that describes the MLflow experiment tracker.

    It exposes the flavor name, documentation links, dashboard logo and the
    config / implementation classes of the MLflow experiment tracker.
    """

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The value of `MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR`.
        """
        flavor_name = MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation of this flavor.

        Returns:
            The URL produced by `generate_default_docs_url`.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation of this flavor.

        Returns:
            The URL produced by `generate_default_sdk_docs_url`.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Link to the image that represents this flavor in the dashboard.

        Returns:
            The URL of the flavor logo.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/experiment_tracker/mlflow.png"
        return logo

    @property
    def config_class(self) -> Type[MLFlowExperimentTrackerConfig]:
        """Configuration class used by this flavor.

        Returns:
            The `MLFlowExperimentTrackerConfig` class.
        """
        return MLFlowExperimentTrackerConfig

    @property
    def implementation_class(self) -> Type["MLFlowExperimentTracker"]:
        """Stack component class that implements this flavor.

        Returns:
            The `MLFlowExperimentTracker` class.
        """
        # Imported inside the property rather than at module level, so the
        # implementation module is only loaded when this property is read.
        from zenml.integrations.mlflow.experiment_trackers import (
            MLFlowExperimentTracker,
        )

        return MLFlowExperimentTracker
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor.MLFlowExperimentTrackerConfig] property readonly

Returns MLFlowExperimentTrackerConfig config class.

Returns:

Type Description
Type[zenml.integrations.mlflow.flavors.mlflow_experiment_tracker_flavor.MLFlowExperimentTrackerConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowExperimentTracker] property readonly

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowExperimentTracker]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

MLFlowExperimentTrackerSettings (BaseSettings)

Settings for the MLflow experiment tracker.

Attributes:

Name Type Description
experiment_name Optional[str]

The MLflow experiment name.

nested bool

If True, will create a nested sub-run for the step.

tags Dict[str, Any]

Tags for the Mlflow run.

Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
class MLFlowExperimentTrackerSettings(BaseSettings):
    """Settings for the MLflow experiment tracker.

    All three settings are optional and fall back to the defaults declared
    below.

    Attributes:
        experiment_name: The MLflow experiment name. Defaults to `None`.
        nested: If `True`, will create a nested sub-run for the step.
            Defaults to `False`.
        tags: Tags for the MLflow run. Defaults to an empty dictionary.
    """

    # No experiment name is set unless the user provides one.
    experiment_name: Optional[str] = None
    # Nested sub-runs are opt-in.
    nested: bool = False
    # NOTE(review): mutable `{}` default; presumably safe because the settings
    # base class copies defaults per instance — confirm.
    tags: Dict[str, Any] = {}
is_databricks_tracking_uri(tracking_uri)

Checks whether the given tracking uri is a Databricks tracking uri.

Parameters:

Name Type Description Default
tracking_uri str

The tracking uri to check.

required

Returns:

Type Description
bool

True if the tracking uri is a Databricks tracking uri, False otherwise.

Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
def is_databricks_tracking_uri(tracking_uri: str) -> bool:
    """Tell whether a tracking uri designates Databricks.

    Only the exact, case-sensitive value `databricks` qualifies.

    Args:
        tracking_uri: The tracking uri to check.

    Returns:
        `True` if the tracking uri is a Databricks tracking uri, `False`
        otherwise.
    """
    databricks_uri = "databricks"
    return databricks_uri == tracking_uri
is_remote_mlflow_tracking_uri(tracking_uri)

Checks whether the given tracking uri is remote or not.

Parameters:

Name Type Description Default
tracking_uri str

The tracking uri to check.

required

Returns:

Type Description
bool

True if the tracking uri is remote, False otherwise.

Source code in zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
def is_remote_mlflow_tracking_uri(tracking_uri: str) -> bool:
    """Tell whether a tracking uri points at a remote MLflow backend.

    A uri is remote when it starts with `http://` or `https://`, or when
    `is_databricks_tracking_uri` recognises it as a Databricks uri.

    Args:
        tracking_uri: The tracking uri to check.

    Returns:
        `True` if the tracking uri is remote, `False` otherwise.
    """
    remote_prefixes = ("http://", "https://")
    # `str.startswith` accepts a tuple and matches any of its entries.
    if tracking_uri.startswith(remote_prefixes):
        return True
    return is_databricks_tracking_uri(tracking_uri)

mlflow_model_deployer_flavor

MLflow model deployer flavor.

MLFlowModelDeployerConfig (BaseModelDeployerConfig)

Configuration for the MLflow model deployer.

Attributes:

Name Type Description
service_path str

the path where the local MLflow deployment service configuration, PID and log files are stored.

Source code in zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py
class MLFlowModelDeployerConfig(BaseModelDeployerConfig):
    """Configuration for the MLflow model deployer.

    Attributes:
        service_path: the path where the local MLflow deployment service
            configuration, PID and log files are stored. Defaults to an
            empty string.
    """

    # Empty by default; the model deployer's `local_path` property falls back
    # to a default path derived from the component ID when this is empty.
    service_path: str = ""

    @property
    def is_local(self) -> bool:
        """Checks if this stack component is running locally.

        This config unconditionally reports itself as local.

        Returns:
            Always True.
        """
        return True
is_local: bool property readonly

Checks if this stack component is running locally.

Returns:

Type Description
bool

True if this config is for a local component, False otherwise.

MLFlowModelDeployerFlavor (BaseModelDeployerFlavor)

Model deployer flavor for MLflow models.

Source code in zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py
class MLFlowModelDeployerFlavor(BaseModelDeployerFlavor):
    """Flavor that describes the MLflow model deployer.

    It exposes the flavor name, documentation links, dashboard logo and the
    config / implementation classes of the MLflow model deployer.
    """

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The value of `MLFLOW_MODEL_DEPLOYER_FLAVOR`.
        """
        flavor_name = MLFLOW_MODEL_DEPLOYER_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation of this flavor.

        Returns:
            The URL produced by `generate_default_docs_url`.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation of this flavor.

        Returns:
            The URL produced by `generate_default_sdk_docs_url`.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Link to the image that represents this flavor in the dashboard.

        Returns:
            The URL of the flavor logo.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png"
        return logo

    @property
    def config_class(self) -> Type[MLFlowModelDeployerConfig]:
        """Configuration class used by this flavor.

        Returns:
            The `MLFlowModelDeployerConfig` class.
        """
        return MLFlowModelDeployerConfig

    @property
    def implementation_class(self) -> Type["MLFlowModelDeployer"]:
        """Stack component class that implements this flavor.

        Returns:
            The `MLFlowModelDeployer` class.
        """
        # Imported inside the property rather than at module level, so the
        # implementation module is only loaded when this property is read.
        from zenml.integrations.mlflow.model_deployers import (
            MLFlowModelDeployer,
        )

        return MLFlowModelDeployer
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig] property readonly

Returns MLFlowModelDeployerConfig config class.

Returns:

Type Description
Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowModelDeployer] property readonly

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowModelDeployer]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

mlflow_model_registry_flavor

MLflow model registry flavor.

MLFlowModelRegistryConfig (BaseModelRegistryConfig)

Configuration for the MLflow model registry.

Source code in zenml/integrations/mlflow/flavors/mlflow_model_registry_flavor.py
class MLFlowModelRegistryConfig(BaseModelRegistryConfig):
    """Configuration for the MLflow model registry.

    Declares no fields of its own; everything is inherited from
    `BaseModelRegistryConfig`.
    """
MLFlowModelRegistryFlavor (BaseModelRegistryFlavor)

Model registry flavor for MLflow models.

Source code in zenml/integrations/mlflow/flavors/mlflow_model_registry_flavor.py
class MLFlowModelRegistryFlavor(BaseModelRegistryFlavor):
    """Flavor that describes the MLflow model registry.

    It exposes the flavor name, documentation links, dashboard logo and the
    config / implementation classes of the MLflow model registry.
    """

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The value of `MLFLOW_MODEL_REGISTRY_FLAVOR`.
        """
        flavor_name = MLFLOW_MODEL_REGISTRY_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation of this flavor.

        Returns:
            The URL produced by `generate_default_docs_url`.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation of this flavor.

        Returns:
            The URL produced by `generate_default_sdk_docs_url`.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Link to the image that represents this flavor in the dashboard.

        Returns:
            The URL of the flavor logo.
        """
        # NOTE(review): the path points at the `model_deployer` logo rather
        # than a registry-specific asset — confirm this is intended.
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png"
        return logo

    @property
    def config_class(self) -> Type[MLFlowModelRegistryConfig]:
        """Configuration class used by this flavor.

        Returns:
            The `MLFlowModelRegistryConfig` class.
        """
        return MLFlowModelRegistryConfig

    @property
    def implementation_class(self) -> Type["MLFlowModelRegistry"]:
        """Stack component class that implements this flavor.

        Returns:
            The `MLFlowModelRegistry` class.
        """
        # Imported inside the property rather than at module level, so the
        # implementation module is only loaded when this property is read.
        from zenml.integrations.mlflow.model_registries import (
            MLFlowModelRegistry,
        )

        return MLFlowModelRegistry
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_model_registry_flavor.MLFlowModelRegistryConfig] property readonly

Returns MLFlowModelRegistryConfig config class.

Returns:

Type Description
Type[zenml.integrations.mlflow.flavors.mlflow_model_registry_flavor.MLFlowModelRegistryConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowModelRegistry] property readonly

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowModelRegistry]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

mlflow_utils

Implementation of utils specific to the MLflow integration.

get_missing_mlflow_experiment_tracker_error()

Returns description of how to add an MLflow experiment tracker to your stack.

Returns:

Type Description
ValueError

The error to raise when no MLflow experiment tracker is registered in the active stack.

Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_missing_mlflow_experiment_tracker_error() -> ValueError:
    """Build the error explaining how to add an MLflow experiment tracker.

    The error is returned, not raised, so the caller decides when to raise
    it.

    Returns:
        A `ValueError` whose message explains how to register an MLflow
        experiment tracker and add it to a stack.
    """
    return ValueError(
        "The active stack needs to have a MLflow experiment tracker "
        "component registered to be able to track experiments using "
        "MLflow. You can create a new stack with a MLflow experiment "
        "tracker component or update your existing stack to add this "
        "component, e.g.:\n\n"
        # The register command selects the implementation with `--flavor`;
        # the previously suggested `--type` option does not do that.
        "  'zenml experiment-tracker register mlflow_tracker "
        "--flavor=mlflow'\n"
        "  'zenml stack register stack-name -e mlflow_tracker ...'\n"
    )

get_tracking_uri()

Gets the MLflow tracking URI from the active experiment tracking stack component.

noqa: DAR401

Returns:

Type Description
str

MLflow tracking URI.

Source code in zenml/integrations/mlflow/mlflow_utils.py
def get_tracking_uri() -> str:
    """Return the tracking URI of the active stack's MLflow experiment tracker.

    The experiment tracker of the active stack is looked up through the
    client. If it is missing or is not an `MLFlowExperimentTracker`, the
    error built by `get_missing_mlflow_experiment_tracker_error` is raised.

    # noqa: DAR401

    Returns:
        MLflow tracking URI.
    """
    from zenml.integrations.mlflow.experiment_trackers.mlflow_experiment_tracker import (
        MLFlowExperimentTracker,
    )

    active_tracker = Client().active_stack.experiment_tracker
    # `isinstance` is False for `None`, so this also covers a stack without
    # any experiment tracker.
    if isinstance(active_tracker, MLFlowExperimentTracker):
        return active_tracker.get_tracking_uri()

    raise get_missing_mlflow_experiment_tracker_error()

is_zenml_run(run)

Checks if a MLflow run is a ZenML run or not.

Parameters:

Name Type Description Default
run mlflow.entities.Run

The run to check.

required

Returns:

Type Description
bool

If the run is a ZenML run.

Source code in zenml/integrations/mlflow/mlflow_utils.py
def is_zenml_run(run: Run) -> bool:
    """Tell whether an MLflow run was created by ZenML.

    A run qualifies when its tags contain the `ZENML_TAG_KEY` key.

    Args:
        run: The run to check.

    Returns:
        If the run is a ZenML run.
    """
    run_tags = run.data.tags
    return ZENML_TAG_KEY in run_tags

stop_zenml_mlflow_runs(status)

Stops active ZenML Mlflow runs.

This function stops all MLflow active runs until no active run exists or a non-ZenML run is active.

Parameters:

Name Type Description Default
status str

The status to set the run to.

required
Source code in zenml/integrations/mlflow/mlflow_utils.py
def stop_zenml_mlflow_runs(status: str) -> None:
    """End the currently active MLflow runs that belong to ZenML.

    Active runs are ended one after another. The loop stops as soon as
    there is no active run left or the active run is not a ZenML run.

    Args:
        status: The status to set the run to.
    """
    while True:
        current_run = mlflow.active_run()
        # Stop when nothing is active or a non-ZenML run is reached.
        if not current_run or not is_zenml_run(current_run):
            return

        logger.debug("Stopping mlflow run %s.", current_run.info.run_id)
        mlflow.end_run(status=status)

model_deployers special

Initialization of the MLflow model deployers.

mlflow_model_deployer

Implementation of the MLflow model deployer.

MLFlowModelDeployer (BaseModelDeployer)

MLflow implementation of the BaseModelDeployer.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
class MLFlowModelDeployer(BaseModelDeployer):
    """MLflow implementation of the BaseModelDeployer.

    Creates, starts, stops and deletes `MLFlowDeploymentService` instances
    whose runtime files live under the directory returned by `local_path`.
    """

    # Display name of this model deployer.
    NAME: ClassVar[str] = "MLflow"
    # Flavor class associated with this model deployer.
    FLAVOR: ClassVar[Type[BaseModelDeployerFlavor]] = MLFlowModelDeployerFlavor

    # Cache for the resolved service root path; filled by `local_path`.
    _service_path: Optional[str] = None

    @property
    def config(self) -> MLFlowModelDeployerConfig:
        """Returns the `MLFlowModelDeployerConfig` config.

        Returns:
            The configuration.
        """
        return cast(MLFlowModelDeployerConfig, self._config)

    @staticmethod
    def get_service_path(id_: UUID) -> str:
        """Get the path where local MLflow service information is stored.

        This is where the deployment service configuration, PID and log
        files are stored. The directory is created if it does not exist yet.

        Args:
            id_: The ID of the MLflow model deployer.

        Returns:
            The service path.
        """
        service_path = os.path.join(
            GlobalConfiguration().local_stores_path,
            str(id_),
        )
        create_dir_recursive_if_not_exists(service_path)
        return service_path

    @property
    def local_path(self) -> str:
        """Returns the path to the root directory.

        This is where all configurations for MLflow deployment daemon processes
        are stored.

        If the service path is not set in the config by the user, the path is
        set to a local default path according to the component ID.

        The resolved path is cached in `_service_path`, so the directory is
        only resolved and created on first access.

        Returns:
            The path to the local service root directory.
        """
        # Reuse the path resolved by an earlier access.
        if self._service_path is not None:
            return self._service_path

        # A non-empty user-configured path wins over the default location.
        if self.config.service_path:
            self._service_path = self.config.service_path
        else:
            self._service_path = self.get_service_path(self.id)

        create_dir_recursive_if_not_exists(self._service_path)
        return self._service_path

    @staticmethod
    def get_model_server_info(  # type: ignore[override]
        service_instance: "MLFlowDeploymentService",
    ) -> Dict[str, Optional[str]]:
        """Return implementation specific information relevant to the user.

        Args:
            service_instance: Instance of an MLFlowDeploymentService.

        Returns:
            A dictionary containing the information: prediction URL, model
            URI and name, registry model name and version, service runtime
            path, daemon PID and health check URL.
        """
        return {
            "PREDICTION_URL": service_instance.endpoint.prediction_url,
            "MODEL_URI": service_instance.config.model_uri,
            "MODEL_NAME": service_instance.config.model_name,
            "REGISTRY_MODEL_NAME": service_instance.config.registry_model_name,
            "REGISTRY_MODEL_VERSION": service_instance.config.registry_model_version,
            "SERVICE_PATH": service_instance.status.runtime_path,
            "DAEMON_PID": str(service_instance.status.pid),
            "HEALTH_CHECK_URL": service_instance.endpoint.monitor.get_healthcheck_uri(
                service_instance.endpoint
            ),
        }

    def perform_deploy_model(
        self,
        id: UUID,
        config: ServiceConfig,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Create and start a new MLflow deployment service.

        The supplied `config` is treated as an `MLFlowDeploymentConfig` and
        handed to `_create_new_service` together with `id` and `timeout`.

        NOTE(review): earlier documentation described a `replace` mode that
        updates an equivalent existing service in place. This method takes
        no `replace` argument and always creates a new service; any such
        replace logic presumably lives in the caller — confirm.

        Args:
            id: the ID of the MLflow deployment service to be created.
            config: the configuration of the model to be deployed with MLflow.
            timeout: the timeout in seconds to wait for the MLflow server
                to be provisioned and successfully started. It is forwarded
                to the service's `start` call.

        Returns:
            The ZenML MLflow deployment service object that can be used to
            interact with the MLflow model server.
        """
        config = cast(MLFlowDeploymentConfig, config)
        service = self._create_new_service(
            id=id, timeout=timeout, config=config
        )
        logger.info(f"Created a new MLflow deployment service: {service}")
        return service

    def _clean_up_existing_service(
        self,
        timeout: int,
        force: bool,
        existing_service: MLFlowDeploymentService,
    ) -> None:
        """Stop a service and remove its runtime directory.

        Args:
            timeout: Timeout in seconds passed to the service's `stop` call.
            force: Passed to the service's `stop` call as `force`.
            existing_service: The service to stop and clean up.
        """
        # stop the older service
        existing_service.stop(timeout=timeout, force=force)

        # delete the old runtime directory, if the service reports one
        if existing_service.status.runtime_path:
            shutil.rmtree(existing_service.status.runtime_path)

    # The step will receive a config from the user that mentions the number
    # of workers etc. The step implementation will create a new config using
    # all values from the user and add values like pipeline name, model_uri.
    def _create_new_service(
        self, id: UUID, timeout: int, config: MLFlowDeploymentConfig
    ) -> MLFlowDeploymentService:
        """Creates a new MLFlowDeploymentService.

        The config's `root_runtime_path` is overwritten with `local_path`
        before the service is created and started.

        Args:
            id: the ID of the MLflow deployment service to be created.
            timeout: the timeout in seconds to wait for the MLflow server
                to be provisioned and successfully started.
            config: the configuration of the model to be deployed with MLflow.

        Returns:
            The MLFlowDeploymentService object that can be used to interact
            with the MLflow model server.
        """
        # set the root runtime path to this model deployer's local path
        config.root_runtime_path = self.local_path
        # create a new service for the new model
        service = MLFlowDeploymentService(uuid=id, config=config)
        service.start(timeout=timeout)

        return service

    def perform_stop_model(
        self,
        service: BaseService,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> BaseService:
        """Method to stop a model server.

        Args:
            service: The service to stop.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.

        Returns:
            The service that was stopped.
        """
        service.stop(timeout=timeout, force=force)
        return service

    def perform_start_model(
        self,
        service: BaseService,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    ) -> BaseService:
        """Method to start a model server.

        Args:
            service: The service to start.
            timeout: Timeout in seconds to wait for the service to start.

        Returns:
            The service that was started.
        """
        service.start(timeout=timeout)
        return service

    def perform_delete_model(
        self,
        service: BaseService,
        timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
        force: bool = False,
    ) -> None:
        """Method to delete all configuration of a model server.

        The service is stopped and its runtime directory is removed via
        `_clean_up_existing_service`.

        Args:
            service: The service to delete.
            timeout: Timeout in seconds to wait for the service to stop.
            force: If True, force the service to stop.
        """
        service = cast(MLFlowDeploymentService, service)
        self._clean_up_existing_service(
            existing_service=service, timeout=timeout, force=force
        )
config: MLFlowModelDeployerConfig property readonly

Returns the MLFlowModelDeployerConfig config.

Returns:

Type Description
MLFlowModelDeployerConfig

The configuration.

local_path: str property readonly

Returns the path to the root directory.

This is where all configurations for MLflow deployment daemon processes are stored.

If the service path is not set in the config by the user, the path is set to a local default path according to the component ID.

Returns:

Type Description
str

The path to the local service root directory.

FLAVOR (BaseModelDeployerFlavor)

Model deployer flavor for MLflow models.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
class MLFlowModelDeployerFlavor(BaseModelDeployerFlavor):
    """Flavor that describes the MLflow model deployer.

    It exposes the flavor name, documentation links, dashboard logo and the
    config / implementation classes of the MLflow model deployer.
    """

    @property
    def name(self) -> str:
        """Identifier of this flavor.

        Returns:
            The value of `MLFLOW_MODEL_DEPLOYER_FLAVOR`.
        """
        flavor_name = MLFLOW_MODEL_DEPLOYER_FLAVOR
        return flavor_name

    @property
    def docs_url(self) -> Optional[str]:
        """Link to the user documentation of this flavor.

        Returns:
            The URL produced by `generate_default_docs_url`.
        """
        url = self.generate_default_docs_url()
        return url

    @property
    def sdk_docs_url(self) -> Optional[str]:
        """Link to the SDK documentation of this flavor.

        Returns:
            The URL produced by `generate_default_sdk_docs_url`.
        """
        url = self.generate_default_sdk_docs_url()
        return url

    @property
    def logo_url(self) -> str:
        """Link to the image that represents this flavor in the dashboard.

        Returns:
            The URL of the flavor logo.
        """
        logo = "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png"
        return logo

    @property
    def config_class(self) -> Type[MLFlowModelDeployerConfig]:
        """Configuration class used by this flavor.

        Returns:
            The `MLFlowModelDeployerConfig` class.
        """
        return MLFlowModelDeployerConfig

    @property
    def implementation_class(self) -> Type["MLFlowModelDeployer"]:
        """Stack component class that implements this flavor.

        Returns:
            The `MLFlowModelDeployer` class.
        """
        # Imported inside the property rather than at module level, so the
        # implementation module is only loaded when this property is read.
        from zenml.integrations.mlflow.model_deployers import (
            MLFlowModelDeployer,
        )

        return MLFlowModelDeployer
config_class: Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig] property readonly

Returns MLFlowModelDeployerConfig config class.

Returns:

Type Description
Type[zenml.integrations.mlflow.flavors.mlflow_model_deployer_flavor.MLFlowModelDeployerConfig]

The config class.

docs_url: Optional[str] property readonly

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowModelDeployer] property readonly

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowModelDeployer]

The implementation class.

logo_url: str property readonly

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property readonly

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property readonly

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

get_model_server_info(service_instance) staticmethod

Return implementation specific information relevant to the user.

Parameters:

Name Type Description Default
service_instance MLFlowDeploymentService

Instance of an MLFlowDeploymentService

required

Returns:

Type Description
Dict[str, Optional[str]]

A dictionary containing the information.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
    """Collect the user-facing details of an MLflow deployment service.

    Args:
        service_instance: The MLflow deployment service to describe.

    Returns:
        A dictionary with the prediction URL, model URI and name, registry
        model name and version, service runtime path, daemon PID (as a
        string) and health check URL.
    """
    endpoint = service_instance.endpoint
    service_config = service_instance.config
    service_status = service_instance.status

    info: Dict[str, Optional[str]] = {}
    info["PREDICTION_URL"] = endpoint.prediction_url
    info["MODEL_URI"] = service_config.model_uri
    info["MODEL_NAME"] = service_config.model_name
    info["REGISTRY_MODEL_NAME"] = service_config.registry_model_name
    info["REGISTRY_MODEL_VERSION"] = service_config.registry_model_version
    info["SERVICE_PATH"] = service_status.runtime_path
    info["DAEMON_PID"] = str(service_status.pid)
    info["HEALTH_CHECK_URL"] = endpoint.monitor.get_healthcheck_uri(endpoint)
    return info
get_service_path(id_) staticmethod

Get the path where local MLflow service information is stored.

This is where the deployment service configuration, PID and log files are stored.

Parameters:

Name Type Description Default
id_ UUID

The ID of the MLflow model deployer.

required

Returns:

Type Description
str

The service path.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
@staticmethod
def get_service_path(id_: UUID) -> str:
    """Resolve the directory that holds local MLflow service information.

    This is where the deployment service configuration, PID and log files
    are stored. The directory is a sub-folder of the global configuration's
    local stores path named after the given ID, and it is created if it
    does not exist yet.

    Args:
        id_: The ID of the MLflow model deployer.

    Returns:
        The service path.
    """
    stores_root = GlobalConfiguration().local_stores_path
    component_dir = os.path.join(stores_root, str(id_))
    create_dir_recursive_if_not_exists(component_dir)
    return component_dir
perform_delete_model(self, service, timeout=60, force=False)

Method to delete all configuration of a model server.

Parameters:

Name Type Description Default
service BaseService

The service to delete.

required
timeout int

Timeout in seconds to wait for the service to stop.

60
force bool

If True, force the service to stop.

False
Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def perform_delete_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Delete all configuration of a model server.

    The work is delegated to `_clean_up_existing_service`.

    Args:
        service: The service to delete.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    # `cast` only narrows the static type; nothing is checked at runtime.
    mlflow_service = cast(MLFlowDeploymentService, service)
    self._clean_up_existing_service(
        timeout=timeout,
        force=force,
        existing_service=mlflow_service,
    )
perform_deploy_model(self, id, config, timeout=60)

Create a new MLflow deployment service or update an existing one.

This should serve the supplied model and deployment configuration.

This method has two modes of operation, depending on the replace argument value:

  • if replace is False, calling this method will create a new MLflow deployment server to reflect the model and other configuration parameters specified in the supplied MLflow service config.

  • if replace is True, this method will first attempt to find an existing MLflow deployment service that is equivalent to the supplied configuration parameters. Two or more MLflow deployment services are considered equivalent if they have the same pipeline_name, pipeline_step_name and model_name configuration parameters. To put it differently, two MLflow deployment services are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent MLflow deployment is found, it will be updated in place to reflect the new configuration parameters.

Callers should set replace to True if they want a continuous model deployment workflow that doesn't spin up a new MLflow deployment server for each new model version. If multiple equivalent MLflow deployment servers are found, one is selected at random to be updated and the others are deleted.

Parameters:

Name Type Description Default
id UUID

the ID of the MLflow deployment service to be created or updated.

required
config ServiceConfig

the configuration of the model to be deployed with MLflow.

required
timeout int

the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow server is provisioned, without waiting for it to fully start.

60

Returns:

Type Description
BaseService

The ZenML MLflow deployment service object that can be used to interact with the MLflow model server.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def perform_deploy_model(
    self,
    id: UUID,
    config: ServiceConfig,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new MLflow deployment service for the supplied config.

    This method always creates a new service through
    `_create_new_service`; it takes no `replace` argument and does not
    look for existing services itself.

    Args:
        id: the ID of the MLflow deployment service to be created.
        config: the configuration of the model to be deployed with MLflow.
        timeout: the timeout in seconds, forwarded unchanged to
            `_create_new_service`.

    Returns:
        The ZenML MLflow deployment service object that can be used to
        interact with the MLflow model server.
    """
    # `cast` only narrows the static type; nothing is checked at runtime.
    mlflow_config = cast(MLFlowDeploymentConfig, config)
    new_service = self._create_new_service(
        id=id,
        timeout=timeout,
        config=mlflow_config,
    )
    logger.info(f"Created a new MLflow deployment service: {new_service}")
    return new_service
perform_start_model(self, service, timeout=60)

Method to start a model server.

Parameters:

Name Type Description Default
service BaseService

The service to start.

required
timeout int

Timeout in seconds to wait for the service to start.

60

Returns:

Type Description
BaseService

The service that was started.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def perform_start_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
    """Start a model server and hand the same service object back.

    Args:
        service: The service to start.
        timeout: Timeout in seconds, forwarded to `service.start`.

    Returns:
        The service that was started (the very object passed in).
    """
    started_service = service
    started_service.start(timeout=timeout)
    return started_service
perform_stop_model(self, service, timeout=60, force=False)

Method to stop a model server.

Parameters:

Name Type Description Default
service BaseService

The service to stop.

required
timeout int

Timeout in seconds to wait for the service to stop.

60
force bool

If True, force the service to stop.

False

Returns:

Type Description
BaseService

The service that was stopped.

Source code in zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
def perform_stop_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> BaseService:
    """Stop a model server and hand the same service object back.

    Args:
        service: The service to stop.
        timeout: Timeout in seconds, forwarded to `service.stop`.
        force: If True, force the service to stop.

    Returns:
        The service that was stopped (the very object passed in).
    """
    stopped_service = service
    stopped_service.stop(timeout=timeout, force=force)
    return stopped_service

model_registries special

Initialization of the MLflow model registry.

mlflow_model_registry

Implementation of the MLflow model registry for ZenML.

MLFlowModelRegistry (BaseModelRegistry)

Register models using MLflow.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
class MLFlowModelRegistry(BaseModelRegistry):
    """Register models using MLflow."""

    _client: Optional[MlflowClient] = None

    @property
    def config(self) -> MLFlowModelRegistryConfig:
        """Return this component's config typed as `MLFlowModelRegistryConfig`.

        Returns:
            The configuration.
        """
        # `cast` only narrows the static type; `self._config` is returned as-is.
        registry_config = cast(MLFlowModelRegistryConfig, self._config)
        return registry_config

    def configure_mlflow(self) -> None:
        """Configure MLflow through the active stack's experiment tracker."""
        tracker = Client().active_stack.experiment_tracker
        # The registry relies on the tracker's own MLflow configuration logic.
        assert isinstance(tracker, MLFlowExperimentTracker)
        tracker.configure_mlflow()

    @property
    def mlflow_client(self) -> MlflowClient:
        """Get the MLflow client, creating and caching it on first access.

        Returns:
            The MLFlowClient.
        """
        if self._client:
            return self._client
        # First access: configure MLflow before instantiating the client.
        self.configure_mlflow()
        self._client = mlflow.tracking.MlflowClient()
        return self._client

    @property
    def validator(self) -> Optional[StackValidator]:
        """Validates that the stack contains an mlflow experiment tracker.

        Returns:
            A StackValidator instance.
        """

        def _release_tuple(version: str) -> Tuple[int, ...]:
            """Parse the leading numeric components of a version string.

            Args:
                version: A dotted version string such as `2.14.1`.

            Returns:
                Exactly three integers, e.g. `(2, 14, 1)`. Missing or
                non-numeric components count as 0.
            """
            release = []
            for part in version.split(".")[:3]:
                digits = ""
                for char in part:
                    if not char.isdigit():
                        break
                    digits += char
                release.append(int(digits) if digits else 0)
            release.extend([0] * (3 - len(release)))
            return tuple(release)

        def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]:
            """Validates that all the requirements are met for the stack.

            Args:
                stack: The stack to validate.

            Returns:
                A tuple of (is_valid, error_message).
            """
            # Validate that the experiment tracker is an mlflow experiment tracker.
            experiment_tracker = stack.experiment_tracker
            assert experiment_tracker is not None
            if experiment_tracker.flavor != "mlflow":
                return False, (
                    "The MLflow model registry requires a MLflow experiment "
                    "tracker. You should register a MLflow experiment "
                    "tracker to the stack using the following command: "
                    "`zenml stack update model_registry -e mlflow_tracker`"
                )
            mlflow_version = mlflow.version.VERSION
            # Compare numerically: a plain string comparison would rank
            # e.g. "10.0.0" below "2.1.1".
            if (
                _release_tuple(mlflow_version) < (2, 1, 1)
                and experiment_tracker.config.is_local
            ):
                return False, (
                    "The MLflow model registry requires MLflow version "
                    f"2.1.1 or higher to use a local MLflow registry. "
                    f"Your current MLflow version is {mlflow_version}. "
                    "You can upgrade MLflow using the following command: "
                    "`pip install --upgrade mlflow`"
                )
            return True, ""

        return StackValidator(
            required_components={
                StackComponentType.EXPERIMENT_TRACKER,
            },
            custom_validation_function=_validate_stack_requirements,
        )

    # ---------
    # Model Registration Methods
    # ---------

    def register_model(
        self,
        name: str,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> RegisteredModel:
        """Register a model to the MLflow model registry.

        Args:
            name: The name of the model.
            description: The description of the model.
            metadata: The metadata of the model.

        Raises:
            RuntimeError: If the model already exists or if MLflow fails to
                create the registered model.

        Returns:
            The registered model.
        """
        # Check if model already exists. `get_model` raises a KeyError when
        # the lookup fails, which is the expected path for a new name. The
        # "already exists" error is raised in the `else` branch so that it is
        # not swallowed by the `except KeyError` handler.
        try:
            self.get_model(name)
        except KeyError:
            pass
        else:
            raise RuntimeError(
                f"Model with name {name} already exists in the MLflow model "
                f"registry. Please use a different name.",
            )
        # Register model.
        try:
            registered_model = self.mlflow_client.create_registered_model(
                name=name,
                description=description,
                tags=metadata,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to register model with name {name} to the MLflow "
                f"model registry: {str(e)}",
            ) from e

        # Return the registered model.
        return RegisteredModel(
            name=registered_model.name,
            description=registered_model.description,
            metadata=registered_model.tags,
        )

    def delete_model(
        self,
        name: str,
    ) -> None:
        """Delete a model from the MLflow model registry.

        Args:
            name: The name of the model.

        Raises:
            KeyError: If the model cannot be fetched from the registry
                (propagated from `get_model`).
            RuntimeError: If MLflow fails to delete the model.
        """
        # Check if model exists; `get_model` raises a KeyError otherwise.
        self.get_model(name=name)
        # Delete the registered model.
        try:
            self.mlflow_client.delete_registered_model(
                name=name,
            )
        except MlflowException as e:
            # Chain the MLflow error so the original traceback is preserved.
            raise RuntimeError(
                f"Failed to delete model with name {name} from MLflow model "
                f"registry: {str(e)}",
            ) from e

    def update_model(
        self,
        name: str,
        description: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
        remove_metadata: Optional[List[str]] = None,
    ) -> RegisteredModel:
        """Update a model in the MLflow model registry.

        Only the aspects for which a truthy argument is supplied are
        touched: description first, then added/updated tags, then removed
        tags.

        Args:
            name: The name of the model.
            description: The description of the model.
            metadata: The metadata of the model.
            remove_metadata: The metadata to remove from the model.

        Raises:
            KeyError: If the model cannot be fetched from the registry
                (propagated from `get_model`).
            RuntimeError: If mlflow fails to update the model.

        Returns:
            The updated model.
        """
        # Check if model exists; `get_model` raises a KeyError otherwise.
        self.get_model(name=name)
        # Update the registered model description.
        if description:
            try:
                self.mlflow_client.update_registered_model(
                    name=name,
                    description=description,
                )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to update description for the model {name} in MLflow "
                    f"model registry: {str(e)}",
                ) from e
        # Update the registered model tags.
        if metadata:
            try:
                for tag, value in metadata.items():
                    self.mlflow_client.set_registered_model_tag(
                        name=name,
                        key=tag,
                        value=value,
                    )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to update tags for the model {name} in MLflow model "
                    f"registry: {str(e)}",
                ) from e
        # Remove tags from the registered model.
        if remove_metadata:
            try:
                for tag in remove_metadata:
                    self.mlflow_client.delete_registered_model_tag(
                        name=name,
                        key=tag,
                    )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to remove tags for the model {name} in MLflow model "
                    f"registry: {str(e)}",
                ) from e
        # Return the updated registered model (re-read from the registry).
        return self.get_model(name)

    def get_model(self, name: str) -> RegisteredModel:
        """Get a model from the MLflow model registry.

        Args:
            name: The name of the model.

        Returns:
            The model.

        Raises:
            KeyError: If mlflow fails to get the model.
        """
        # Get the registered model.
        try:
            registered_model = self.mlflow_client.get_registered_model(
                name=name,
            )
        except MlflowException as e:
            # Any MLflow failure is reported as a KeyError; the original
            # exception is chained so its details are not lost.
            raise KeyError(
                f"Failed to get model with name {name} from the MLflow model "
                f"registry: {str(e)}",
            ) from e
        # Return the registered model.
        return RegisteredModel(
            name=registered_model.name,
            description=registered_model.description,
            metadata=registered_model.tags,
        )

    def list_models(
        self,
        name: Optional[str] = None,
        metadata: Optional[Dict[str, str]] = None,
    ) -> List[RegisteredModel]:
        """List models in the MLflow model registry.

        The optional name and tag filters are combined with `AND` into a
        single MLflow filter string. At most 100 results are requested.

        Args:
            name: A name to filter the models by.
            metadata: The metadata to filter the models by.

        Returns:
            A list of models (RegisteredModel)
        """
        # Collect the individual filter clauses, then join them.
        clauses = []
        if name:
            clauses.append(f"name='{name}'")
        if metadata:
            clauses.extend(
                f"tags.{tag}='{value}'" for tag, value in metadata.items()
            )

        # Query the registry.
        matches = self.mlflow_client.search_registered_models(
            filter_string=" AND ".join(clauses),
            max_results=100,
        )

        # Convert the MLflow entities to ZenML registered models.
        results = []
        for entry in matches:
            results.append(
                RegisteredModel(
                    name=entry.name,
                    description=entry.description,
                    metadata=entry.tags,
                )
            )
        return results

    # ---------
    # Model Version Methods
    # ---------

    def register_model_version(
        self,
        name: str,
        version: Optional[str] = None,
        model_source_uri: Optional[str] = None,
        description: Optional[str] = None,
        metadata: Optional[ModelRegistryModelMetadata] = None,
        **kwargs: Any,
    ) -> RegistryModelVersion:
        """Register a model version to the MLflow model registry.

        If no registered model with the given name can be fetched, one is
        registered first.

        Args:
            name: The name of the model.
            version: The version of the model. It is not passed on to
                MLflow; an informational message is logged if it is set.
            model_source_uri: The source URI of the model.
            description: The description of the model version.
            metadata: The registry metadata of the model version.
            **kwargs: Additional keyword arguments.

        Raises:
            RuntimeError: If MLflow fails to create the model version (or
                the registered model, when it had to be created).

        Returns:
            The registered model version.
        """
        # Check if the model exists, if not create it.
        try:
            self.get_model(name=name)
        except KeyError:
            logger.info(
                f"No registered model with name {name} found. Creating a new "
                "registered model."
            )
            self.register_model(
                name=name,
            )
        try:
            # Inform the user that the version is ignored.
            if version:
                logger.info(
                    f"MLflow model registry does not take a version as an argument. "
                    f"Registering a new version for the model `'{name}'` "
                    f"a version will be assigned automatically."
                )
            metadata_dict = metadata.model_dump() if metadata else {}
            # Set the run ID and link.
            run_id = metadata_dict.get("mlflow_run_id", None)
            run_link = metadata_dict.get("mlflow_run_link", None)
            # Register the model version.
            registered_model_version = self.mlflow_client.create_model_version(
                name=name,
                source=model_source_uri,
                run_id=run_id,
                run_link=run_link,
                description=description,
                tags=metadata_dict,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to register model version with name '{name}' and "
                f"version '{version}' to the MLflow model registry. "
                f"Error: {e}"
            ) from e
        # Return the registered model version.
        return self._cast_mlflow_version_to_model_version(
            registered_model_version
        )

    def delete_model_version(
        self,
        name: str,
        version: str,
    ) -> None:
        """Delete a model version from the MLflow model registry.

        Args:
            name: The name of the model.
            version: The version of the model.

        Raises:
            KeyError: If the model version cannot be fetched from the
                registry (propagated from `get_model_version`).
            RuntimeError: If mlflow fails to delete the model version.
        """
        # Make sure the version exists; a KeyError is raised otherwise.
        self.get_model_version(name=name, version=version)
        try:
            self.mlflow_client.delete_model_version(
                name=name,
                version=version,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to delete model version '{version}' of model '{name}' "
                f"from the MLflow model registry: {str(e)}",
            ) from e

    def update_model_version(
        self,
        name: str,
        version: str,
        description: Optional[str] = None,
        metadata: Optional[ModelRegistryModelMetadata] = None,
        remove_metadata: Optional[List[str]] = None,
        stage: Optional[ModelVersionStage] = None,
    ) -> RegistryModelVersion:
        """Update a model version in the MLflow model registry.

        Only the aspects for which a truthy argument is supplied are
        touched, in this order: description, tags to set, tags to remove,
        stage.

        Args:
            name: The name of the model.
            version: The version of the model.
            description: The description of the model version.
            metadata: The metadata of the model version.
            remove_metadata: The metadata to remove from the model version.
            stage: The stage of the model version.

        Raises:
            KeyError: If the model version cannot be fetched from the
                registry (propagated from `get_model_version`).
            RuntimeError: If mlflow fails to update the model version.

        Returns:
            The updated model version.
        """
        # Make sure the version exists; a KeyError is raised otherwise.
        self.get_model_version(name=name, version=version)
        # Update the model description.
        if description:
            try:
                self.mlflow_client.update_model_version(
                    name=name,
                    version=version,
                    description=description,
                )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to update the description of model version "
                    f"'{name}:{version}' in the MLflow model registry: {str(e)}"
                ) from e
        # Update the model tags.
        if metadata:
            try:
                for key, value in metadata.model_dump().items():
                    self.mlflow_client.set_model_version_tag(
                        name=name,
                        version=version,
                        key=key,
                        value=value,
                    )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to update the tags of model version "
                    f"'{name}:{version}' in the MLflow model registry: {str(e)}"
                ) from e
        # Remove the model tags.
        if remove_metadata:
            try:
                for key in remove_metadata:
                    self.mlflow_client.delete_model_version_tag(
                        name=name,
                        version=version,
                        key=key,
                    )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to remove the tags of model version "
                    f"'{name}:{version}' in the MLflow model registry: {str(e)}"
                ) from e
        # Update the model stage.
        if stage:
            try:
                self.mlflow_client.transition_model_version_stage(
                    name=name,
                    version=version,
                    stage=stage.value,
                )
            except MlflowException as e:
                raise RuntimeError(
                    f"Failed to update the current stage of model version "
                    f"'{name}:{version}' in the MLflow model registry: {str(e)}"
                ) from e
        # Re-read the version so the returned object reflects all updates.
        return self.get_model_version(name, version)

    def get_model_version(
        self,
        name: str,
        version: str,
    ) -> RegistryModelVersion:
        """Get a model version from the MLflow model registry.

        Args:
            name: The name of the model.
            version: The version of the model.

        Raises:
            KeyError: If mlflow fails to get the model version.

        Returns:
            The model version.
        """
        # Get the model version from the MLflow model registry.
        try:
            mlflow_model_version = self.mlflow_client.get_model_version(
                name=name,
                version=version,
            )
        except MlflowException as e:
            # Any MLflow failure is reported as a KeyError; the original
            # exception is chained so its details are not lost.
            raise KeyError(
                f"Failed to get model version '{name}:{version}' from the "
                f"MLflow model registry: {str(e)}"
            ) from e
        # Return the model version.
        return self._cast_mlflow_version_to_model_version(
            mlflow_model_version=mlflow_model_version,
        )

    def list_model_versions(
        self,
        name: Optional[str] = None,
        model_source_uri: Optional[str] = None,
        metadata: Optional[ModelRegistryModelMetadata] = None,
        stage: Optional[ModelVersionStage] = None,
        count: Optional[int] = None,
        created_after: Optional[datetime] = None,
        created_before: Optional[datetime] = None,
        order_by_date: Optional[str] = None,
        **kwargs: Any,
    ) -> List[RegistryModelVersion]:
        """List model versions from the MLflow model registry.

        Name, source URI, run ID and metadata filters are sent to MLflow as
        a filter string. Stage, creation time and count are applied locally
        to the returned versions.

        Args:
            name: The name of the model.
            model_source_uri: The model source URI.
            metadata: The metadata of the model version. Only fields with a
                truthy value are used as tag filters.
            stage: The stage of the model version.
            count: The maximum number of model versions to return.
            created_after: The minimum creation time of the model versions.
            created_before: The maximum creation time of the model versions.
            order_by_date: The order of the model versions by creation time,
                either `asc` or `desc`. Any other value is ignored.
            kwargs: Additional keyword arguments. A truthy `mlflow_run_id`
                restricts the results to that MLflow run.

        Returns:
            The model versions.
        """
        # Set the filter string.
        filter_string = ""
        if name:
            filter_string += f"name='{name}'"
        if model_source_uri:
            if filter_string:
                filter_string += " AND "
            filter_string += f"source='{model_source_uri}'"
        if "mlflow_run_id" in kwargs and kwargs["mlflow_run_id"]:
            if filter_string:
                filter_string += " AND "
            filter_string += f"run_id='{kwargs['mlflow_run_id']}'"
        if metadata:
            for tag, value in metadata.model_dump().items():
                if value:
                    if filter_string:
                        filter_string += " AND "
                    filter_string += f"tags.{tag}='{value}'"

        # Translate the requested ordering; anything but asc/desc is ignored.
        order_by = []
        if order_by_date == "asc":
            order_by = ["creation_timestamp ASC"]
        elif order_by_date == "desc":
            order_by = ["creation_timestamp DESC"]

        # Get the model versions.
        mlflow_model_versions = self.mlflow_client.search_model_versions(
            filter_string=filter_string,
            order_by=order_by,
        )
        # Cast the MLflow model versions to the ZenML model version class.
        model_versions = []
        for mlflow_model_version in mlflow_model_versions:
            # check if given MlFlow model version matches the given request
            # before casting it
            if stage and not mlflow_model_version.current_stage == str(stage):
                continue
            if created_after or created_before:
                # MLflow reports the creation time in milliseconds, whereas
                # `datetime.timestamp()` is in seconds; convert before
                # comparing (same conversion as in
                # `_cast_mlflow_version_to_model_version`).
                created_at = (
                    int(mlflow_model_version.creation_timestamp) / 1e3
                )
                if created_after and created_at < created_after.timestamp():
                    continue
                if created_before and created_at > created_before.timestamp():
                    continue
            try:
                model_versions.append(
                    self._cast_mlflow_version_to_model_version(
                        mlflow_model_version=mlflow_model_version,
                    )
                )
            except (AttributeError, OSError) as e:
                # Sometimes, the Model Registry in MLflow can become unusable
                # due to failed version registration or misuse. In such rare
                # cases, it's best to suppress those versions that are not usable.
                logger.warning(
                    "Error encountered while loading MLflow model version "
                    f"`{mlflow_model_version.name}:{mlflow_model_version.version}`: {e}"
                )
            # Stop as soon as the requested number of versions is collected.
            if count and len(model_versions) == count:
                return model_versions

        return model_versions

    def load_model_version(
        self,
        name: str,
        version: str,
        **kwargs: Any,
    ) -> Any:
        """Load a model version from the MLflow model registry.

        The model version is looked up in the registry and its source URI
        is handed to `load_model` together with any extra keyword
        arguments.

        Args:
            name: The name of the model.
            version: The version of the model.
            kwargs: Additional keyword arguments passed to `load_model`.

        Returns:
            The loaded model.

        Raises:
            KeyError: If the model version does not exist.
        """
        try:
            model_version = self.get_model_version(name=name, version=version)
        except KeyError as e:
            raise KeyError(
                f"Failed to load model version '{name}:{version}' from the "
                f"MLflow model registry: Model version does not exist."
            ) from e
        # Load the model version. The lookup above already carries the source
        # URI, so no second request to the registry is needed.
        return load_model(
            model_uri=model_version.model_source_uri,
            **kwargs,
        )

    def get_model_uri_artifact_store(
        self,
        model_version: RegistryModelVersion,
    ) -> str:
        """Build the model URI inside the active stack's artifact store.

        The result is the artifact store path, followed by `/mlflow`,
        followed by whatever comes after the last `:` in the model source
        URI (or the whole URI if it contains no `:`).

        Args:
            model_version: The model version.

        Returns:
            The model URI artifact store.
        """
        store_root = Client().active_stack.artifact_store.path
        # Keep only the part of the source URI after the last colon.
        uri_tail = model_version.model_source_uri.rpartition(":")[2]
        return f"{store_root}/mlflow" + uri_tail

    def _cast_mlflow_version_to_model_version(
        self,
        mlflow_model_version: MLflowModelVersion,
    ) -> RegistryModelVersion:
        """Cast an MLflow model version to a model version.

        Args:
            mlflow_model_version: The MLflow model version.

        Returns:
            The model version.
        """
        # Work on a copy so that adding the run information below never
        # mutates a mapping owned by the MLflow entity.
        metadata = dict(mlflow_model_version.tags or {})
        if mlflow_model_version.run_id:
            metadata["mlflow_run_id"] = mlflow_model_version.run_id
        if mlflow_model_version.run_link:
            metadata["mlflow_run_link"] = mlflow_model_version.run_link

        try:
            from mlflow.models import get_model_info

            # The loader module of the `python_function` flavor is reported
            # as the model library; it is None if that flavor is absent.
            model_library = (
                get_model_info(model_uri=mlflow_model_version.source)
                .flavors.get("python_function", {})
                .get("loader_module")
            )
        except ImportError:
            model_library = None
        return RegistryModelVersion(
            registered_model=RegisteredModel(name=mlflow_model_version.name),
            model_format=MLFLOW_MODEL_FORMAT,
            model_library=model_library,
            version=str(mlflow_model_version.version),
            # The timestamps are divided by 1e3 before being handed to
            # `datetime.fromtimestamp`, which expects seconds.
            created_at=datetime.fromtimestamp(
                int(mlflow_model_version.creation_timestamp) / 1e3
            ),
            stage=ModelVersionStage(mlflow_model_version.current_stage),
            description=mlflow_model_version.description,
            last_updated_at=datetime.fromtimestamp(
                int(mlflow_model_version.last_updated_timestamp) / 1e3
            ),
            metadata=ModelRegistryModelMetadata(**metadata),
            model_source_uri=mlflow_model_version.source,
        )
config: MLFlowModelRegistryConfig property readonly

Returns the MLFlowModelRegistryConfig config.

Returns:

Type Description
MLFlowModelRegistryConfig

The configuration.

mlflow_client: mlflow.MlflowClient property readonly

Get the MLflow client.

Returns:

Type Description
mlflow.MlflowClient

The MLFlowClient.

validator: Optional[zenml.stack.stack_validator.StackValidator] property readonly

Validates that the stack contains an mlflow experiment tracker.

Returns:

Type Description
Optional[zenml.stack.stack_validator.StackValidator]

A StackValidator instance.

configure_mlflow(self)

Configures the MLflow Client with the experiment tracker config.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def configure_mlflow(self) -> None:
    """Apply the active stack's experiment tracker settings to MLflow.

    The actual work is delegated to the experiment tracker of the active
    stack, which is expected to be an `MLFlowExperimentTracker`.
    """
    active_stack = Client().active_stack
    tracker = active_stack.experiment_tracker
    # Narrow the tracker type before delegating to it.
    assert isinstance(tracker, MLFlowExperimentTracker)
    tracker.configure_mlflow()
delete_model(self, name)

Delete a model from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required

Exceptions:

Type Description
RuntimeError

If MLflow fails to delete the model. (A model that does not exist results in a KeyError raised by get_model.)

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def delete_model(
    self,
    name: str,
) -> None:
    """Delete a model from the MLflow model registry.

    Args:
        name: The name of the model.

    Raises:
        KeyError: If the model does not exist (propagated from
            `get_model`).
        RuntimeError: If MLflow fails to delete the model.
    """
    # Fail early (with a KeyError) if the registry does not know the model.
    self.get_model(name=name)
    # Delete the registered model.
    try:
        self.mlflow_client.delete_registered_model(
            name=name,
        )
    except MlflowException as e:
        # Chain explicitly so the MLflow error stays attached as the cause.
        raise RuntimeError(
            f"Failed to delete model with name {name} from MLflow model "
            f"registry: {str(e)}",
        ) from e
delete_model_version(self, name, version)

Delete a model version from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required

Exceptions:

Type Description
RuntimeError

If mlflow fails to delete the model version.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def delete_model_version(
    self,
    name: str,
    version: str,
) -> None:
    """Delete a model version from the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.

    Raises:
        KeyError: If the model version does not exist (propagated from
            `get_model_version`).
        RuntimeError: If mlflow fails to delete the model version.
    """
    # Fail early (with a KeyError) if the version is unknown to the registry.
    self.get_model_version(name=name, version=version)
    try:
        self.mlflow_client.delete_model_version(
            name=name,
            version=version,
        )
    except MlflowException as e:
        # The two message parts used to be glued together without a space
        # ("...'name'.From the MLflow..."); emit one readable sentence.
        raise RuntimeError(
            f"Failed to delete model version '{version}' of model '{name}' "
            f"from the MLflow model registry: {str(e)}",
        ) from e
get_model(self, name)

Get a model from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required

Returns:

Type Description
RegisteredModel

The model.

Exceptions:

Type Description
KeyError

If mlflow fails to get the model.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def get_model(self, name: str) -> RegisteredModel:
    """Fetch a registered model from the MLflow model registry.

    Args:
        name: The name of the model.

    Returns:
        The model.

    Raises:
        KeyError: If mlflow fails to get the model.
    """
    try:
        mlflow_model = self.mlflow_client.get_registered_model(name=name)
    except MlflowException as e:
        message = (
            f"Failed to get model with name {name} from the MLflow model "
            f"registry: {str(e)}"
        )
        raise KeyError(message)
    # Convert the MLflow entity into the ZenML representation.
    model_fields = {
        "name": mlflow_model.name,
        "description": mlflow_model.description,
        "metadata": mlflow_model.tags,
    }
    return RegisteredModel(**model_fields)
get_model_uri_artifact_store(self, model_version)

Get the model URI artifact store.

Parameters:

Name Type Description Default
model_version RegistryModelVersion

The model version.

required

Returns:

Type Description
str

The model URI artifact store.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def get_model_uri_artifact_store(
    self,
    model_version: RegistryModelVersion,
) -> str:
    """Build the model URI inside the active artifact store.

    Args:
        model_version: The model version.

    Returns:
        The model URI artifact store.
    """
    store_root = Client().active_stack.artifact_store.path
    # Keep only the part of the source URI that follows its last ":".
    _, _, uri_tail = model_version.model_source_uri.rpartition(":")
    return f"{store_root}/mlflow" + uri_tail
get_model_version(self, name, version)

Get a model version from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required

Exceptions:

Type Description
KeyError

If the model version does not exist.

Returns:

Type Description
RegistryModelVersion

The model version.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def get_model_version(
    self,
    name: str,
    version: str,
) -> RegistryModelVersion:
    """Fetch a single model version from the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.

    Raises:
        KeyError: If the model version does not exist.

    Returns:
        The model version.
    """
    lookup = {"name": name, "version": version}
    try:
        mlflow_version = self.mlflow_client.get_model_version(**lookup)
    except MlflowException as e:
        message = (
            f"Failed to get model version '{name}:{version}' from the "
            f"MLflow model registry: {str(e)}"
        )
        raise KeyError(message)
    # Convert the MLflow entity into its ZenML counterpart.
    return self._cast_mlflow_version_to_model_version(
        mlflow_model_version=mlflow_version,
    )
list_model_versions(self, name=None, model_source_uri=None, metadata=None, stage=None, count=None, created_after=None, created_before=None, order_by_date=None, **kwargs)

List model versions from the MLflow model registry.

Parameters:

Name Type Description Default
name Optional[str]

The name of the model.

None
model_source_uri Optional[str]

The model source URI.

None
metadata Optional[zenml.model_registries.base_model_registry.ModelRegistryModelMetadata]

The metadata of the model version.

None
stage Optional[zenml.model_registries.base_model_registry.ModelVersionStage]

The stage of the model version.

None
count Optional[int]

The maximum number of model versions to return.

None
created_after Optional[datetime.datetime]

The minimum creation time of the model versions.

None
created_before Optional[datetime.datetime]

The maximum creation time of the model versions.

None
order_by_date Optional[str]

The order of the model versions by creation time, either ascending or descending.

None
kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
List[zenml.model_registries.base_model_registry.RegistryModelVersion]

The model versions.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def list_model_versions(
    self,
    name: Optional[str] = None,
    model_source_uri: Optional[str] = None,
    metadata: Optional[ModelRegistryModelMetadata] = None,
    stage: Optional[ModelVersionStage] = None,
    count: Optional[int] = None,
    created_after: Optional[datetime] = None,
    created_before: Optional[datetime] = None,
    order_by_date: Optional[str] = None,
    **kwargs: Any,
) -> List[RegistryModelVersion]:
    """List model versions from the MLflow model registry.

    Name, source URI, run ID and metadata are turned into an MLflow search
    filter; stage and creation-time constraints are applied afterwards on
    the returned versions.

    Args:
        name: The name of the model.
        model_source_uri: The model source URI.
        metadata: The metadata of the model version.
        stage: The stage of the model version.
        count: The maximum number of model versions to return.
        created_after: The minimum creation time of the model versions.
        created_before: The maximum creation time of the model versions.
        order_by_date: The order of the model versions by creation time,
            either ascending or descending.
        kwargs: Additional keyword arguments. `mlflow_run_id` is used to
            filter by MLflow run ID.

    Returns:
        The model versions.
    """
    # Build the MLflow search filter from the provided criteria.
    conditions = []
    if name:
        conditions.append(f"name='{name}'")
    if model_source_uri:
        conditions.append(f"source='{model_source_uri}'")
    if kwargs.get("mlflow_run_id"):
        conditions.append(f"run_id='{kwargs['mlflow_run_id']}'")
    if metadata:
        # Only metadata fields that carry a value become tag filters.
        conditions.extend(
            f"tags.{tag}='{value}'"
            for tag, value in metadata.model_dump().items()
            if value
        )
    filter_string = " AND ".join(conditions)

    # Any value other than "asc"/"desc" leaves the ordering unspecified.
    order_by = []
    if order_by_date == "asc":
        order_by = ["creation_timestamp ASC"]
    elif order_by_date == "desc":
        order_by = ["creation_timestamp DESC"]

    # Get the model versions.
    mlflow_model_versions = self.mlflow_client.search_model_versions(
        filter_string=filter_string,
        order_by=order_by,
    )

    # Cast the MLflow model versions to the ZenML model version class.
    model_versions = []
    for mlflow_model_version in mlflow_model_versions:
        # Check if the given MLflow model version matches the request
        # before casting it. MLflow stage names correspond to the enum
        # values, so compare against `stage.value`.
        if stage and mlflow_model_version.current_stage != stage.value:
            continue
        if created_after or created_before:
            # MLflow timestamps are in milliseconds, whereas
            # `datetime.timestamp()` is in seconds; convert before
            # comparing.
            created_at = int(mlflow_model_version.creation_timestamp) / 1e3
            if created_after and created_at < created_after.timestamp():
                continue
            if created_before and created_at > created_before.timestamp():
                continue
        try:
            model_versions.append(
                self._cast_mlflow_version_to_model_version(
                    mlflow_model_version=mlflow_model_version,
                )
            )
        except (AttributeError, OSError) as e:
            # Sometimes, the Model Registry in MLflow can become unusable
            # due to failed version registration or misuse. In such rare
            # cases, it's best to suppress those versions that are not usable.
            logger.warning(
                "Error encountered while loading MLflow model version "
                f"`{mlflow_model_version.name}:{mlflow_model_version.version}`: {e}"
            )
        # Stop as soon as the requested number of versions is collected.
        if count and len(model_versions) == count:
            break

    return model_versions
list_models(self, name=None, metadata=None)

List models in the MLflow model registry.

Parameters:

Name Type Description Default
name Optional[str]

A name to filter the models by.

None
metadata Optional[Dict[str, str]]

The metadata to filter the models by.

None

Returns:

Type Description
List[zenml.model_registries.base_model_registry.RegisteredModel]

A list of models (RegisteredModel)

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def list_models(
    self,
    name: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> List[RegisteredModel]:
    """Search the MLflow model registry for registered models.

    Args:
        name: A name to filter the models by.
        metadata: The metadata to filter the models by.

    Returns:
        A list of models (RegisteredModel)
    """
    # Collect the individual filter conditions, then AND them together.
    conditions = [f"name='{name}'"] if name else []
    if metadata:
        conditions.extend(
            f"tags.{tag}='{value}'" for tag, value in metadata.items()
        )
    filter_string = " AND ".join(conditions)

    # NOTE(review): a single search call capped at 100 results is made here;
    # no further pages are requested.
    search_results = self.mlflow_client.search_registered_models(
        filter_string=filter_string,
        max_results=100,
    )

    # Convert every MLflow entity into the ZenML representation.
    models = []
    for mlflow_model in search_results:
        models.append(
            RegisteredModel(
                name=mlflow_model.name,
                description=mlflow_model.description,
                metadata=mlflow_model.tags,
            )
        )
    return models
load_model_version(self, name, version, **kwargs)

Load a model version from the MLflow model registry.

This method loads the model version from the MLflow model registry and returns the model. The model is loaded using the mlflow.pyfunc module which takes care of loading the model from the model source URI for the right framework.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required
kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Any

The model version.

Exceptions:

Type Description
KeyError

If the model version does not exist.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def load_model_version(
    self,
    name: str,
    version: str,
    **kwargs: Any,
) -> Any:
    """Load a model version from the MLflow model registry.

    This method loads the model version from the MLflow model registry
    and returns the model. The model is loaded using the `mlflow.pyfunc`
    module which takes care of loading the model from the model source
    URI for the right framework.

    Args:
        name: The name of the model.
        version: The version of the model.
        kwargs: Additional keyword arguments, forwarded to `load_model`.

    Returns:
        The model version.

    Raises:
        KeyError: If the model version does not exist.
    """
    try:
        registry_model_version = self.get_model_version(
            name=name, version=version
        )
    except KeyError as e:
        # Keep the original lookup failure attached as the cause.
        raise KeyError(
            f"Failed to load model version '{name}:{version}' from the "
            f"MLflow model registry: Model version does not exist."
        ) from e
    # The registry entry already carries the MLflow source URI, so a second
    # lookup of the same version is not needed.
    return load_model(
        model_uri=registry_model_version.model_source_uri,
        **kwargs,
    )
register_model(self, name, description=None, metadata=None)

Register a model to the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
description Optional[str]

The description of the model.

None
metadata Optional[Dict[str, str]]

The metadata of the model.

None

Exceptions:

Type Description
RuntimeError

If the model already exists.

Returns:

Type Description
RegisteredModel

The registered model.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def register_model(
    self,
    name: str,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> RegisteredModel:
    """Register a model to the MLflow model registry.

    Args:
        name: The name of the model.
        description: The description of the model.
        metadata: The metadata of the model.

    Raises:
        RuntimeError: If the model already exists or MLflow fails to
            register it.

    Returns:
        The registered model.
    """
    # Check if model already exists. The check used to raise a KeyError
    # inside the same `try` that swallowed KeyError, so it never fired;
    # use `else` so the "already exists" error actually reaches the caller.
    try:
        self.get_model(name)
    except KeyError:
        # Expected path: the name is not taken yet.
        pass
    else:
        raise RuntimeError(
            f"Model with name {name} already exists in the MLflow model "
            f"registry. Please use a different name.",
        )
    # Register model.
    try:
        registered_model = self.mlflow_client.create_registered_model(
            name=name,
            description=description,
            tags=metadata,
        )
    except MlflowException as e:
        raise RuntimeError(
            f"Failed to register model with name {name} to the MLflow "
            f"model registry: {str(e)}",
        ) from e

    # Return the registered model.
    return RegisteredModel(
        name=registered_model.name,
        description=registered_model.description,
        metadata=registered_model.tags,
    )
register_model_version(self, name, version=None, model_source_uri=None, description=None, metadata=None, **kwargs)

Register a model version to the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
model_source_uri Optional[str]

The source URI of the model.

None
version Optional[str]

The version of the model.

None
description Optional[str]

The description of the model version.

None
metadata Optional[zenml.model_registries.base_model_registry.ModelRegistryModelMetadata]

The registry metadata of the model version.

None
**kwargs Any

Additional keyword arguments.

{}

Exceptions:

Type Description
RuntimeError

If MLflow fails to register the model or the model version. (A registered model that does not exist yet is created automatically.)

Returns:

Type Description
RegistryModelVersion

The registered model version.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def register_model_version(
    self,
    name: str,
    version: Optional[str] = None,
    model_source_uri: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[ModelRegistryModelMetadata] = None,
    **kwargs: Any,
) -> RegistryModelVersion:
    """Register a model version to the MLflow model registry.

    If no registered model with the given name exists yet, it is created
    first.

    Args:
        name: The name of the model.
        model_source_uri: The source URI of the model.
        version: The version of the model. It is ignored, since the version
            is not passed on to MLflow.
        description: The description of the model version.
        metadata: The registry metadata of the model version.
        **kwargs: Additional keyword arguments.

    Raises:
        RuntimeError: If MLflow fails to register the model or the model
            version.

    Returns:
        The registered model version.
    """
    # Check if the model exists, if not create it.
    try:
        self.get_model(name=name)
    except KeyError:
        logger.info(
            f"No registered model with name {name} found. Creating a new "
            "registered model."
        )
        self.register_model(
            name=name,
        )
    # Inform the user that the version is ignored.
    if version:
        logger.info(
            f"MLflow model registry does not take a version as an argument. "
            f"Registering a new version for the model `'{name}'`; "
            f"a version will be assigned automatically."
        )
    metadata_dict = metadata.model_dump() if metadata else {}
    # The run ID and link are passed both as dedicated arguments and as part
    # of the tags.
    run_id = metadata_dict.get("mlflow_run_id")
    run_link = metadata_dict.get("mlflow_run_link")
    # Register the model version.
    try:
        registered_model_version = self.mlflow_client.create_model_version(
            name=name,
            source=model_source_uri,
            run_id=run_id,
            run_link=run_link,
            description=description,
            tags=metadata_dict,
        )
    except MlflowException as e:
        raise RuntimeError(
            f"Failed to register model version with name '{name}' and "
            f"version '{version}' to the MLflow model registry. "
            f"Error: {e}"
        ) from e
    # Return the registered model version.
    return self._cast_mlflow_version_to_model_version(
        registered_model_version
    )
update_model(self, name, description=None, metadata=None, remove_metadata=None)

Update a model in the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
description Optional[str]

The description of the model.

None
metadata Optional[Dict[str, str]]

The metadata of the model.

None
remove_metadata Optional[List[str]]

The metadata to remove from the model.

None

Exceptions:

Type Description
RuntimeError

If mlflow fails to update the model.

Returns:

Type Description
RegisteredModel

The updated model.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def update_model(
    self,
    name: str,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
    remove_metadata: Optional[List[str]] = None,
) -> RegisteredModel:
    """Change description and/or tags of a model in the MLflow registry.

    Args:
        name: The name of the model.
        description: The description of the model.
        metadata: The metadata of the model.
        remove_metadata: The metadata to remove from the model.

    Raises:
        RuntimeError: If mlflow fails to update the model.

    Returns:
        The updated model.
    """
    # Make sure the model is known to the registry before touching it.
    self.get_model(name=name)

    # Step 1: description.
    if description:
        try:
            self.mlflow_client.update_registered_model(
                name=name, description=description
            )
        except MlflowException as e:
            message = (
                f"Failed to update description for the model {name} in MLflow "
                f"model registry: {str(e)}"
            )
            raise RuntimeError(message)

    # Step 2: tags to add or overwrite.
    if metadata:
        try:
            for key, tag_value in metadata.items():
                self.mlflow_client.set_registered_model_tag(
                    name=name, key=key, value=tag_value
                )
        except MlflowException as e:
            message = (
                f"Failed to update tags for the model {name} in MLflow model "
                f"registry: {str(e)}"
            )
            raise RuntimeError(message)

    # Step 3: tags to drop.
    if remove_metadata:
        try:
            for key in remove_metadata:
                self.mlflow_client.delete_registered_model_tag(
                    name=name, key=key
                )
        except MlflowException as e:
            message = (
                f"Failed to remove tags for the model {name} in MLflow model "
                f"registry: {str(e)}"
            )
            raise RuntimeError(message)

    # Re-read the model so the caller sees its current state.
    return self.get_model(name)
update_model_version(self, name, version, description=None, metadata=None, remove_metadata=None, stage=None)

Update a model version in the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required
description Optional[str]

The description of the model version.

None
metadata Optional[zenml.model_registries.base_model_registry.ModelRegistryModelMetadata]

The metadata of the model version.

None
remove_metadata Optional[List[str]]

The metadata to remove from the model version.

None
stage Optional[zenml.model_registries.base_model_registry.ModelVersionStage]

The stage of the model version.

None

Exceptions:

Type Description
RuntimeError

If mlflow fails to update the model version.

Returns:

Type Description
RegistryModelVersion

The updated model version.

Source code in zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
def update_model_version(
    self,
    name: str,
    version: str,
    description: Optional[str] = None,
    metadata: Optional[ModelRegistryModelMetadata] = None,
    remove_metadata: Optional[List[str]] = None,
    stage: Optional[ModelVersionStage] = None,
) -> RegistryModelVersion:
    """Change description, tags and/or stage of a registered model version.

    Args:
        name: The name of the model.
        version: The version of the model.
        description: The description of the model version.
        metadata: The metadata of the model version.
        remove_metadata: The metadata to remove from the model version.
        stage: The stage of the model version.

    Raises:
        RuntimeError: If mlflow fails to update the model version.

    Returns:
        The updated model version.
    """
    # Make sure the version is known to the registry before touching it.
    self.get_model_version(name=name, version=version)

    # Step 1: description.
    if description:
        try:
            self.mlflow_client.update_model_version(
                name=name, version=version, description=description
            )
        except MlflowException as e:
            message = (
                f"Failed to update the description of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            )
            raise RuntimeError(message)

    # Step 2: tags to add or overwrite.
    if metadata:
        try:
            for tag_key, tag_value in metadata.model_dump().items():
                self.mlflow_client.set_model_version_tag(
                    name=name,
                    version=version,
                    key=tag_key,
                    value=tag_value,
                )
        except MlflowException as e:
            message = (
                f"Failed to update the tags of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            )
            raise RuntimeError(message)

    # Step 3: tags to drop.
    if remove_metadata:
        try:
            for tag_key in remove_metadata:
                self.mlflow_client.delete_model_version_tag(
                    name=name, version=version, key=tag_key
                )
        except MlflowException as e:
            message = (
                f"Failed to remove the tags of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            )
            raise RuntimeError(message)

    # Step 4: stage transition.
    if stage:
        try:
            self.mlflow_client.transition_model_version_stage(
                name=name, version=version, stage=stage.value
            )
        except MlflowException as e:
            message = (
                f"Failed to update the current stage of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            )
            raise RuntimeError(message)

    # Re-read the version so the caller sees its current state.
    return self.get_model_version(name, version)

services special

Initialization of the MLflow Service.

mlflow_deployment

Implementation of the MLflow deployment functionality.

MLFlowDeploymentConfig (LocalDaemonServiceConfig)

MLflow model deployment configuration.

Attributes:

Name Type Description
model_uri str

URI of the MLflow model to serve

model_name str

the name of the model

workers int

number of workers to use for the prediction service

registry_model_name Optional[str]

the name of the model in the registry

registry_model_version Optional[str]

the version of the model in the registry

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

timeout int

timeout in seconds for starting and stopping the service

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentConfig(LocalDaemonServiceConfig):
    """MLflow model deployment configuration.

    Attributes:
        model_uri: URI of the MLflow model to serve
        model_name: the name of the model
        workers: number of workers to use for the prediction service
        registry_model_name: the name of the model in the registry
        registry_model_version: the version of the model in the registry
        registry_model_stage: presumably the stage of the model version in
            the registry (inferred from the field name; its use is not
            visible in this section)
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        timeout: timeout in seconds for starting and stopping the service
    """

    # Required fields.
    model_uri: str
    model_name: str
    # Optional registry information; all default to None.
    registry_model_name: Optional[str] = None
    registry_model_version: Optional[str] = None
    registry_model_stage: Optional[str] = None
    # Serving options.
    workers: int = 1
    mlserver: bool = False
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT
MLFlowDeploymentEndpoint (LocalDaemonServiceEndpoint)

A service endpoint exposed by the MLflow deployment daemon.

Attributes:

Name Type Description
config MLFlowDeploymentEndpointConfig

service endpoint configuration

monitor HTTPEndpointHealthMonitor

optional service endpoint health monitor

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpoint(LocalDaemonServiceEndpoint):
    """A service endpoint exposed by the MLflow deployment daemon.

    Attributes:
        config: service endpoint configuration
        monitor: optional service endpoint health monitor
    """

    config: MLFlowDeploymentEndpointConfig
    monitor: HTTPEndpointHealthMonitor

    @property
    def prediction_url(self) -> Optional[str]:
        """Gets the prediction URL for the endpoint.

        Returns:
            the prediction URL for the endpoint, or None if the endpoint
            status has no URI
        """
        uri = self.status.uri
        if not uri:
            return None
        # URLs always use "/" as separator. `os.path.join` is meant for
        # filesystem paths: it uses "\\" on Windows and discards the base
        # when the subpath starts with "/", so join the two parts manually.
        subpath = self.config.prediction_url_path
        return f"{uri.rstrip('/')}/{subpath.lstrip('/')}"
prediction_url: Optional[str] property readonly

Gets the prediction URL for the endpoint.

Returns:

Type Description
Optional[str]

the prediction URL for the endpoint

MLFlowDeploymentEndpointConfig (LocalDaemonServiceEndpointConfig)

MLflow daemon service endpoint configuration.

Attributes:

Name Type Description
prediction_url_path str

URI subpath for prediction requests

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentEndpointConfig(LocalDaemonServiceEndpointConfig):
    """MLflow daemon service endpoint configuration.

    Attributes:
        prediction_url_path: URI subpath for prediction requests; the
            endpoint appends it to its URI to build the prediction URL
    """

    # Required: no default value is provided.
    prediction_url_path: str
MLFlowDeploymentService (LocalDaemonService, BaseDeploymentService)

MLflow deployment service used to start a local prediction server for MLflow models.

Attributes:

Name Type Description
SERVICE_TYPE ClassVar[zenml.services.service_type.ServiceType]

a service type descriptor with information describing the MLflow deployment service class

config MLFlowDeploymentConfig

service configuration

endpoint MLFlowDeploymentEndpoint

optional service endpoint

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
class MLFlowDeploymentService(LocalDaemonService, BaseDeploymentService):
    """MLflow deployment service used to start a local prediction server for MLflow models.

    Attributes:
        SERVICE_TYPE: a service type descriptor with information describing
            the MLflow deployment service class
        config: service configuration
        endpoint: optional service endpoint
    """

    SERVICE_TYPE = ServiceType(
        name="mlflow-deployment",
        type="model-serving",
        flavor="mlflow",
        description="MLflow prediction service",
        logo_url="https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png",
    )

    config: MLFlowDeploymentConfig
    endpoint: MLFlowDeploymentEndpoint

    def __init__(
        self,
        config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
        **attrs: Any,
    ) -> None:
        """Initialize the MLflow deployment service.

        When `config` is an `MLFlowDeploymentConfig` and the caller did not
        supply an `endpoint` attribute, a default HTTP endpoint with a health
        monitor is created. Its prediction and healthcheck paths depend on
        whether `config.mlserver` is set.

        Args:
            config: service configuration
            attrs: additional attributes to set on the service
        """
        # ensure that the endpoint is created before the service is initialized
        # TODO [ENG-700]: implement a service factory or builder for MLflow
        #   deployment services
        if (
            isinstance(config, MLFlowDeploymentConfig)
            and "endpoint" not in attrs
        ):
            if config.mlserver:
                prediction_url_path = MLSERVER_PREDICTION_URL_PATH
                healthcheck_uri_path = MLSERVER_HEALTHCHECK_URL_PATH
                use_head_request = False
            else:
                prediction_url_path = MLFLOW_PREDICTION_URL_PATH
                healthcheck_uri_path = MLFLOW_HEALTHCHECK_URL_PATH
                use_head_request = True

            endpoint = MLFlowDeploymentEndpoint(
                config=MLFlowDeploymentEndpointConfig(
                    protocol=ServiceEndpointProtocol.HTTP,
                    prediction_url_path=prediction_url_path,
                ),
                monitor=HTTPEndpointHealthMonitor(
                    config=HTTPEndpointHealthMonitorConfig(
                        healthcheck_uri_path=healthcheck_uri_path,
                        use_head_request=use_head_request,
                    )
                ),
            )
            attrs["endpoint"] = endpoint
        super().__init__(config=config, **attrs)

    def run(self) -> None:
        """Start the service.

        Builds a `PyFuncBackend` and serves `config.model_uri` on the
        endpoint's port. A `KeyboardInterrupt` raised while serving is caught
        and logged instead of being propagated.

        Raises:
            ValueError: if the active stack doesn't have an MLflow experiment
                tracker
        """
        logger.info(
            "Starting MLflow prediction service as blocking "
            "process... press CTRL+C once to stop it."
        )

        self.endpoint.prepare_for_start()
        try:
            backend_kwargs: Dict[str, Any] = {}
            serve_kwargs: Dict[str, Any] = {}
            # Parse the version once; only major and minor are needed.
            major, minor = (
                int(part) for part in MLFLOW_VERSION.split(".")[:2]
            )
            # MLflow version 1.26 introduces an additional mandatory
            # `timeout` argument to the `PyFuncBackend.serve` function.
            # Compare (major, minor) as a tuple so the minor version is
            # never judged independently of the major version.
            if (major, minor) >= (1, 26):
                serve_kwargs["timeout"] = None
            # MLflow 2.0+ requires the env_manager to be set to "local"
            # to deploy the model in the local running environment
            if major >= 2:
                backend_kwargs["env_manager"] = "local"
            backend = PyFuncBackend(
                config={},
                no_conda=True,
                workers=self.config.workers,
                install_mlflow=False,
                **backend_kwargs,
            )
            experiment_tracker = Client().active_stack.experiment_tracker
            if not isinstance(experiment_tracker, MLFlowExperimentTracker):
                raise ValueError(
                    "MLflow model deployer step requires an MLflow experiment "
                    "tracker. Please add an MLflow experiment tracker to your "
                    "stack."
                )
            experiment_tracker.configure_mlflow()
            backend.serve(
                model_uri=self.config.model_uri,
                port=self.endpoint.status.port,
                host="localhost",
                enable_mlserver=self.config.mlserver,
                **serve_kwargs,
            )
        except KeyboardInterrupt:
            logger.info(
                "MLflow prediction service stopped. Resuming normal execution."
            )

    @property
    def prediction_url(self) -> Optional[str]:
        """Get the URI where the prediction service is answering requests.

        Returns:
            The URI where the prediction service can be contacted to process
            HTTP/REST inference requests, or None, if the service isn't running.
        """
        if not self.is_running:
            return None
        return self.endpoint.prediction_url

    def predict(
        self, request: Union["NDArray[Any]", pd.DataFrame]
    ) -> "NDArray[Any]":
        """Make a prediction using the service.

        The request is POSTed as JSON under the "instances" key: DataFrames
        are sent as a list of row records, anything else via `tolist()`.

        Args:
            request: a Numpy Array or Pandas DataFrame representing the request

        Returns:
            A numpy array representing the prediction returned by the service.

        Raises:
            Exception: if the service is not running
            ValueError: if the prediction endpoint is unknown.
        """
        if not self.is_running:
            raise Exception(
                "MLflow prediction service is not running. "
                "Please start the service before making predictions."
            )

        prediction_url = self.endpoint.prediction_url
        if prediction_url is None:
            raise ValueError("No endpoint known for prediction.")

        # Use isinstance rather than an exact type comparison so DataFrame
        # subclasses are serialized as records too; they have no `tolist()`
        # and would otherwise fail in the array branch.
        if isinstance(request, pd.DataFrame):
            instances = request.to_dict("records")
        else:
            instances = request.tolist()

        response = requests.post(  # nosec
            prediction_url,
            json={"instances": instances},
        )
        response.raise_for_status()
        if int(MLFLOW_VERSION.split(".")[0]) <= 1:
            return np.array(response.json())
        # MLflow 2.0+ returns a dictionary with the predictions
        # under the "predictions" key
        return np.array(response.json()["predictions"])
prediction_url: Optional[str] property readonly

Get the URI where the prediction service is answering requests.

Returns:

Type Description
Optional[str]

The URI where the prediction service can be contacted to process HTTP/REST inference requests, or None, if the service isn't running.

__init__(self, config, **attrs) special

Initialize the MLflow deployment service.

Parameters:

Name Type Description Default
config Union[zenml.integrations.mlflow.services.mlflow_deployment.MLFlowDeploymentConfig, Dict[str, Any]]

service configuration

required
attrs Any

additional attributes to set on the service

{}
Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def __init__(
    self,
    config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Set up the MLflow deployment service and its default endpoint.

    If `config` is an `MLFlowDeploymentConfig` and no `endpoint` attribute
    was passed, an HTTP endpoint with a health monitor is built and injected
    into `attrs` before the parent initializer runs.

    Args:
        config: service configuration
        attrs: additional attributes to set on the service
    """
    # The endpoint has to exist before the service itself is initialized.
    # TODO [ENG-700]: implement a service factory or builder for MLflow
    #   deployment services
    needs_default_endpoint = (
        isinstance(config, MLFlowDeploymentConfig) and "endpoint" not in attrs
    )
    if needs_default_endpoint:
        # MLServer and the plain MLflow server expose different URL paths,
        # and only the plain MLflow server is probed with HEAD requests.
        serving_with_mlserver = config.mlserver
        url_subpath = (
            MLSERVER_PREDICTION_URL_PATH
            if serving_with_mlserver
            else MLFLOW_PREDICTION_URL_PATH
        )
        health_subpath = (
            MLSERVER_HEALTHCHECK_URL_PATH
            if serving_with_mlserver
            else MLFLOW_HEALTHCHECK_URL_PATH
        )
        probe_with_head = not serving_with_mlserver

        endpoint_config = MLFlowDeploymentEndpointConfig(
            protocol=ServiceEndpointProtocol.HTTP,
            prediction_url_path=url_subpath,
        )
        health_monitor = HTTPEndpointHealthMonitor(
            config=HTTPEndpointHealthMonitorConfig(
                healthcheck_uri_path=health_subpath,
                use_head_request=probe_with_head,
            )
        )
        attrs["endpoint"] = MLFlowDeploymentEndpoint(
            config=endpoint_config,
            monitor=health_monitor,
        )
    super().__init__(config=config, **attrs)
predict(self, request)

Make a prediction using the service.

Parameters:

Name Type Description Default
request Union[NDArray[Any], pandas.core.frame.DataFrame]

a Numpy Array or Pandas DataFrame representing the request

required

Returns:

Type Description
NDArray[Any]

A numpy array representing the prediction returned by the service.

Exceptions:

Type Description
Exception

if the service is not running

ValueError

if the prediction endpoint is unknown.

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def predict(
    self, request: Union["NDArray[Any]", pd.DataFrame]
) -> "NDArray[Any]":
    """Make a prediction using the service.

    The request is POSTed as JSON under the "instances" key: DataFrames are
    sent as a list of row records, anything else via `tolist()`.

    Args:
        request: a Numpy Array or Pandas DataFrame representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not running
        ValueError: if the prediction endpoint is unknown.
    """
    if not self.is_running:
        raise Exception(
            "MLflow prediction service is not running. "
            "Please start the service before making predictions."
        )

    prediction_url = self.endpoint.prediction_url
    if prediction_url is None:
        raise ValueError("No endpoint known for prediction.")

    # Use isinstance rather than an exact type comparison so DataFrame
    # subclasses are serialized as records too; they have no `tolist()`
    # and would otherwise fail in the array branch.
    if isinstance(request, pd.DataFrame):
        instances = request.to_dict("records")
    else:
        instances = request.tolist()

    response = requests.post(  # nosec
        prediction_url,
        json={"instances": instances},
    )
    response.raise_for_status()
    if int(MLFLOW_VERSION.split(".")[0]) <= 1:
        return np.array(response.json())
    # MLflow 2.0+ returns a dictionary with the predictions
    # under the "predictions" key
    return np.array(response.json()["predictions"])
run(self)

Start the service.

Exceptions:

Type Description
ValueError

if the active stack doesn't have an MLflow experiment tracker

Source code in zenml/integrations/mlflow/services/mlflow_deployment.py
def run(self) -> None:
    """Start the service.

    Builds a `PyFuncBackend` and serves `config.model_uri` on the endpoint's
    port. A `KeyboardInterrupt` raised while serving is caught and logged
    instead of being propagated.

    Raises:
        ValueError: if the active stack doesn't have an MLflow experiment
            tracker
    """
    logger.info(
        "Starting MLflow prediction service as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()
    try:
        backend_kwargs: Dict[str, Any] = {}
        serve_kwargs: Dict[str, Any] = {}
        # Parse the version once; only major and minor are needed.
        major, minor = (int(part) for part in MLFLOW_VERSION.split(".")[:2])
        # MLflow version 1.26 introduces an additional mandatory
        # `timeout` argument to the `PyFuncBackend.serve` function.
        # Compare (major, minor) as a tuple so the minor version is never
        # judged independently of the major version.
        if (major, minor) >= (1, 26):
            serve_kwargs["timeout"] = None
        # MLflow 2.0+ requires the env_manager to be set to "local"
        # to deploy the model in the local running environment
        if major >= 2:
            backend_kwargs["env_manager"] = "local"
        backend = PyFuncBackend(
            config={},
            no_conda=True,
            workers=self.config.workers,
            install_mlflow=False,
            **backend_kwargs,
        )
        experiment_tracker = Client().active_stack.experiment_tracker
        if not isinstance(experiment_tracker, MLFlowExperimentTracker):
            raise ValueError(
                "MLflow model deployer step requires an MLflow experiment "
                "tracker. Please add an MLflow experiment tracker to your "
                "stack."
            )
        experiment_tracker.configure_mlflow()
        backend.serve(
            model_uri=self.config.model_uri,
            port=self.endpoint.status.port,
            host="localhost",
            enable_mlserver=self.config.mlserver,
            **serve_kwargs,
        )
    except KeyboardInterrupt:
        logger.info(
            "MLflow prediction service stopped. Resuming normal execution."
        )

steps special

Initialization of the MLflow standard interface steps.

mlflow_deployer

Implementation of the MLflow model deployer pipeline step.

mlflow_registry

Implementation of the MLflow model registration pipeline step.