Skip to content

Mlflow

zenml.integrations.mlflow

Initialization for the ZenML MLflow integration.

The MLflow integration currently enables you to use MLflow tracking as a convenient way to visualize your experiment runs within the MLflow UI.

Attributes

MLFLOW = 'mlflow' module-attribute

MLFLOW_MODEL_DEPLOYER_FLAVOR = 'mlflow' module-attribute

MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR = 'mlflow' module-attribute

MLFLOW_MODEL_REGISTRY_FLAVOR = 'mlflow' module-attribute

logger = get_logger(__name__) module-attribute

Classes

Flavor

Class for ZenML Flavors.

Attributes
config_class: Type[StackComponentConfig] abstractmethod property

Returns StackComponentConfig config class.

Returns:

Type Description
Type[StackComponentConfig]

The config class.

config_schema: Dict[str, Any] property

The config schema for a flavor.

Returns:

Type Description
Dict[str, Any]

The config schema.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[StackComponent] abstractmethod property

Implementation class for this flavor.

Returns:

Type Description
Type[StackComponent]

The implementation class for this flavor.

logo_url: Optional[str] property

A url to represent the flavor in the dashboard.

Returns:

Type Description
Optional[str]

The flavor logo.

name: str abstractmethod property

The flavor name.

Returns:

Type Description
str

The flavor name.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

service_connector_requirements: Optional[ServiceConnectorRequirements] property

Service connector resource requirements for service connectors.

Specifies resource requirements that are used to filter the available service connector types that are compatible with this flavor.

Returns:

Type Description
Optional[ServiceConnectorRequirements]

Requirements for compatible service connectors, if a service

Optional[ServiceConnectorRequirements]

connector is required for this flavor.

type: StackComponentType abstractmethod property

The stack component type.

Returns:

Type Description
StackComponentType

The stack component type.

Functions
from_model(flavor_model: FlavorResponse) -> Flavor classmethod

Loads a flavor from a model.

Parameters:

Name Type Description Default
flavor_model FlavorResponse

The model to load from.

required

Raises:

Type Description
CustomFlavorImportError

If the custom flavor can't be imported.

ImportError

If the flavor can't be imported.

Returns:

Type Description
Flavor

The loaded flavor.

Source code in src/zenml/stack/flavor.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
@classmethod
def from_model(cls, flavor_model: FlavorResponse) -> "Flavor":
    """Loads a flavor from a model.

    Args:
        flavor_model: The model to load from.

    Raises:
        CustomFlavorImportError: If the custom flavor can't be imported.
        ImportError: If the flavor can't be imported.

    Returns:
        The loaded flavor.
    """
    try:
        flavor = source_utils.load(flavor_model.source)()
    except (ModuleNotFoundError, ImportError, NotImplementedError) as err:
        if flavor_model.is_custom:
            # Derive the module path (without the class name) so we can
            # tell the user where the flavor file is expected to live.
            flavor_module, _ = flavor_model.source.rsplit(".", maxsplit=1)
            expected_file_path = os.path.join(
                source_utils.get_source_root(),
                flavor_module.replace(".", os.path.sep),
            )
            # Chain the original error so the underlying import failure
            # is preserved as `__cause__` in the traceback.
            raise CustomFlavorImportError(
                f"Couldn't import custom flavor {flavor_model.name}: "
                f"{err}. Make sure the custom flavor class "
                f"`{flavor_model.source}` is importable. If it is part of "
                "a library, make sure it is installed. If "
                "it is a local code file, make sure it exists at "
                f"`{expected_file_path}.py`."
            ) from err
        else:
            raise ImportError(
                f"Couldn't import flavor {flavor_model.name}: {err}"
            ) from err
    return cast(Flavor, flavor)
generate_default_docs_url() -> str

Generate the doc urls for all inbuilt and integration flavors.

Note that this method is not going to be useful for custom flavors, which do not have any docs in the main zenml docs.

Returns:

Type Description
str

The complete url to the zenml documentation

Source code in src/zenml/stack/flavor.py
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
def generate_default_docs_url(self) -> str:
    """Build the documentation URL for built-in and integration flavors.

    Note that this method is not going to be useful for custom flavors,
    which do not have any docs in the main zenml docs.

    Returns:
        The complete url to the zenml documentation
    """
    from zenml import __version__

    component_slug = self.type.plural.replace("_", "-")
    flavor_slug = self.name.replace("_", "-")

    try:
        is_latest = is_latest_zenml_version()
    except RuntimeError:
        # We assume in error cases that we are on the latest version
        is_latest = True

    base = (
        "https://docs.zenml.io"
        if is_latest
        else f"https://zenml-io.gitbook.io/zenml-legacy-documentation/v/{__version__}"
    )
    return f"{base}/stack-components/{component_slug}/{flavor_slug}"
generate_default_sdk_docs_url() -> str

Generate SDK docs url for a flavor.

Returns:

Type Description
str

The complete url to the zenml SDK docs

Source code in src/zenml/stack/flavor.py
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def generate_default_sdk_docs_url(self) -> str:
    """Generate SDK docs url for a flavor.

    Returns:
        The complete url to the zenml SDK docs
    """
    from zenml import __version__

    root = f"https://sdkdocs.zenml.io/{__version__}"
    plural_type = self.type.plural

    if "zenml.integrations" not in self.__module__:
        # Core (non-integration) flavors live under the core code docs.
        return (
            f"{root}/core_code_docs/core-{plural_type}/"
            f"#{self.__module__}"
        )

    # The module path looks like "zenml.integrations.<integration>...." —
    # extract the integration name right after the prefix.
    integration_name = self.__module__.split(
        "zenml.integrations.", maxsplit=1
    )[1].split(".")[0]
    return (
        f"{root}/integration_code_docs"
        f"/integrations-{integration_name}/#{self.__module__}"
    )
to_model(integration: Optional[str] = None, is_custom: bool = True) -> FlavorRequest

Converts a flavor to a model.

Parameters:

Name Type Description Default
integration Optional[str]

The integration to use for the model.

None
is_custom bool

Whether the flavor is a custom flavor.

True

Returns:

Type Description
FlavorRequest

The model.

Source code in src/zenml/stack/flavor.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def to_model(
    self,
    integration: Optional[str] = None,
    is_custom: bool = True,
) -> FlavorRequest:
    """Converts a flavor to a model.

    Args:
        integration: The integration to use for the model.
        is_custom: Whether the flavor is a custom flavor.

    Returns:
        The model.
    """
    requirements = self.service_connector_requirements

    def _requirement_attr(attr: str) -> Optional[str]:
        # Service connector requirements are optional; every derived
        # attribute falls back to None when none are declared.
        return getattr(requirements, attr) if requirements else None

    return FlavorRequest(
        name=self.name,
        type=self.type,
        source=source_utils.resolve(self.__class__).import_path,
        config_schema=self.config_schema,
        connector_type=_requirement_attr("connector_type"),
        connector_resource_type=_requirement_attr("resource_type"),
        connector_resource_id_attr=_requirement_attr("resource_id_attr"),
        integration=integration,
        logo_url=self.logo_url,
        docs_url=self.docs_url,
        sdk_docs_url=self.sdk_docs_url,
        is_custom=is_custom,
    )

Integration

Base class for integration in ZenML.

Functions
activate() -> None classmethod

Abstract method to activate the integration.

Source code in src/zenml/integrations/integration.py
175
176
177
@classmethod
def activate(cls) -> None:
    """Activation hook for the integration; the base implementation is a no-op that subclasses may override."""
check_installation() -> bool classmethod

Method to check whether the required packages are installed.

Returns:

Type Description
bool

True if all required packages are installed, False otherwise.

Source code in src/zenml/integrations/integration.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@classmethod
def check_installation(cls) -> bool:
    """Method to check whether the required packages are installed.

    Every requirement returned by ``get_requirements()`` is checked
    twice: first that the distribution itself is installed, then that
    each of its own dependencies (including any requested extras) is
    installed at a compatible version.

    Returns:
        True if all required packages are installed, False otherwise.
    """
    for r in cls.get_requirements():
        try:
            # First check if the base package is installed
            dist = pkg_resources.get_distribution(r)

            # Next, check if the dependencies (including extras) are
            # installed
            deps: List[Requirement] = []

            _, extras = parse_requirement(r)
            if extras:
                # Extras arrive as a bracketed string like "[a,b]":
                # strip the brackets and split on commas.
                extra_list = extras[1:-1].split(",")
                for extra in extra_list:
                    try:
                        requirements = dist.requires(extras=[extra])  # type: ignore[arg-type]
                    except pkg_resources.UnknownExtra as e:
                        # An extra the distribution doesn't know about
                        # means the requirement can't be satisfied.
                        logger.debug(f"Unknown extra: {str(e)}")
                        return False
                    deps.extend(requirements)
            else:
                deps = dist.requires()

            for ri in deps:
                try:
                    # Remove the "extra == ..." part from the requirement string
                    cleaned_req = re.sub(
                        r"; extra == \"\w+\"", "", str(ri)
                    )
                    pkg_resources.get_distribution(cleaned_req)
                except pkg_resources.DistributionNotFound as e:
                    logger.debug(
                        f"Unable to find required dependency "
                        f"'{e.req}' for requirement '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False
                except pkg_resources.VersionConflict as e:
                    logger.debug(
                        f"Package version '{e.dist}' does not match "
                        f"version '{e.req}' required by '{r}' "
                        f"necessary for integration '{cls.NAME}'."
                    )
                    return False

        # Failures on the top-level requirement itself: missing
        # distribution or an installed version outside the spec.
        except pkg_resources.DistributionNotFound as e:
            logger.debug(
                f"Unable to find required package '{e.req}' for "
                f"integration {cls.NAME}."
            )
            return False
        except pkg_resources.VersionConflict as e:
            logger.debug(
                f"Package version '{e.dist}' does not match version "
                f"'{e.req}' necessary for integration {cls.NAME}."
            )
            return False

    logger.debug(
        f"Integration {cls.NAME} is installed correctly with "
        f"requirements {cls.get_requirements()}."
    )
    return True
flavors() -> List[Type[Flavor]] classmethod

Abstract method to declare new stack component flavors.

Returns:

Type Description
List[Type[Flavor]]

A list of new stack component flavors.

Source code in src/zenml/integrations/integration.py
179
180
181
182
183
184
185
186
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare stack component flavors contributed by an integration.

    The base implementation contributes none; integrations override this
    to register their own flavors.

    Returns:
        A list of new stack component flavors.
    """
    return []
get_requirements(target_os: Optional[str] = None, python_version: Optional[str] = None) -> List[str] classmethod

Method to get the requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None
python_version Optional[str]

The Python version to use for the requirements.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
@classmethod
def get_requirements(
    cls,
    target_os: Optional[str] = None,
    python_version: Optional[str] = None,
) -> List[str]:
    """Return the pip requirements declared for this integration.

    Args:
        target_os: The target operating system to get the requirements for.
        python_version: The Python version to use for the requirements.

    Returns:
        A list of requirements.
    """
    # The base implementation ignores the OS/Python filters and returns
    # the static requirement list; subclasses may specialize.
    return cls.REQUIREMENTS
get_uninstall_requirements(target_os: Optional[str] = None) -> List[str] classmethod

Method to get the uninstall requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/integration.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
@classmethod
def get_uninstall_requirements(
    cls, target_os: Optional[str] = None
) -> List[str]:
    """Method to get the uninstall requirements for the integration.

    Args:
        target_os: The target operating system to get the requirements for.

    Returns:
        A list of requirements.
    """
    # Drop every requirement whose name starts with one of the prefixes
    # that must be kept installed on uninstall.
    ignored_prefixes = cls.REQUIREMENTS_IGNORED_ON_UNINSTALL
    return [
        requirement
        for requirement in cls.get_requirements(target_os=target_os)
        if not any(
            requirement.startswith(prefix) for prefix in ignored_prefixes
        )
    ]
plugin_flavors() -> List[Type[BasePluginFlavor]] classmethod

Abstract method to declare new plugin flavors.

Returns:

Type Description
List[Type[BasePluginFlavor]]

A list of new plugin flavors.

Source code in src/zenml/integrations/integration.py
188
189
190
191
192
193
194
195
@classmethod
def plugin_flavors(cls) -> List[Type["BasePluginFlavor"]]:
    """Declare plugin flavors contributed by an integration.

    The base implementation contributes none; subclasses override this
    to register their own plugin flavors.

    Returns:
        A list of new plugin flavors.
    """
    return []

MlflowIntegration

Bases: Integration

Definition of MLflow integration for ZenML.

Functions
activate() -> None classmethod

Activate the MLflow integration.

Source code in src/zenml/integrations/mlflow/__init__.py
74
75
76
77
@classmethod
def activate(cls) -> None:
    """Activate the MLflow integration."""
    # Imported purely for its import-time side effects (presumably
    # registering the MLflow service classes — confirm in the services
    # module); the name itself is unused, hence the noqa.
    from zenml.integrations.mlflow import services  # noqa
flavors() -> List[Type[Flavor]] classmethod

Declare the stack component flavors for the MLflow integration.

Returns:

Type Description
List[Type[Flavor]]

List of stack component flavors for this integration.

Source code in src/zenml/integrations/mlflow/__init__.py
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
@classmethod
def flavors(cls) -> List[Type[Flavor]]:
    """Declare the stack component flavors for the MLflow integration.

    Returns:
        List of stack component flavors for this integration.
    """
    # Imported lazily so the flavor classes are only loaded on demand.
    from zenml.integrations.mlflow.flavors import (
        MLFlowExperimentTrackerFlavor,
        MLFlowModelDeployerFlavor,
        MLFlowModelRegistryFlavor,
    )

    mlflow_flavors: List[Type[Flavor]] = [
        MLFlowModelDeployerFlavor,
        MLFlowExperimentTrackerFlavor,
        MLFlowModelRegistryFlavor,
    ]
    return mlflow_flavors
get_requirements(target_os: Optional[str] = None, python_version: Optional[str] = None) -> List[str] classmethod

Method to get the requirements for the integration.

Parameters:

Name Type Description Default
target_os Optional[str]

The target operating system to get the requirements for.

None
python_version Optional[str]

The Python version to use for the requirements.

None

Returns:

Type Description
List[str]

A list of requirements.

Source code in src/zenml/integrations/mlflow/__init__.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
@classmethod
def get_requirements(
    cls, target_os: Optional[str] = None, python_version: Optional[str] = None
) -> List[str]:
    """Method to get the requirements for the integration.

    Args:
        target_os: The target operating system to get the requirements for.
        python_version: The Python version to use for the requirements.

    Returns:
        A list of requirements.
    """
    from zenml.integrations.numpy import NumpyIntegration
    from zenml.integrations.pandas import PandasIntegration

    requirements = [
        "mlflow>=2.1.1,<3",
        # TODO: remove this requirement once rapidjson is fixed
        "python-rapidjson<1.15",
    ]

    # MLflow also needs the numpy and pandas integration requirements.
    for sub_integration in (NumpyIntegration, PandasIntegration):
        requirements.extend(
            sub_integration.get_requirements(
                target_os=target_os, python_version=python_version
            )
        )
    return requirements

Functions

get_logger(logger_name: str) -> logging.Logger

Main function to get a logger by name.

Parameters:

Name Type Description Default
logger_name str

Name of logger to initialize.

required

Returns:

Type Description
Logger

A logger object.

Source code in src/zenml/logger.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def get_logger(logger_name: str) -> logging.Logger:
    """Create and configure a logger with the given name.

    Args:
        logger_name: Name of logger to initialize.

    Returns:
        A logger object.
    """
    new_logger = logging.getLogger(logger_name)
    new_logger.setLevel(get_logging_level().value)
    new_logger.addHandler(get_console_handler())
    # Avoid duplicate log lines from handlers attached to ancestor loggers.
    new_logger.propagate = False
    return new_logger

Modules

experiment_trackers

Initialization of the MLflow experiment tracker.

Classes
MLFlowExperimentTracker(*args: Any, **kwargs: Any)

Bases: BaseExperimentTracker

Track experiments using MLflow.

Initialize the experiment tracker and validate the tracking uri.

Parameters:

Name Type Description Default
*args Any

Variable length argument list.

()
**kwargs Any

Arbitrary keyword arguments.

{}
Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
66
67
68
69
70
71
72
73
74
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the experiment tracker and validate the tracking uri.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    # Let the base component set up its state first, then verify that
    # the configured tracking URI is well-formed.
    super().__init__(*args, **kwargs)
    self._ensure_valid_tracking_uri()
Attributes
config: MLFlowExperimentTrackerConfig property

Returns the MLFlowExperimentTrackerConfig config.

Returns:

Type Description
MLFlowExperimentTrackerConfig

The configuration.

local_path: Optional[str] property

Path to the local directory where the MLflow artifacts are stored.

Returns:

Type Description
Optional[str]

None if configured with a remote tracking URI, otherwise the

Optional[str]

path to the local MLflow artifact store directory.

settings_class: Optional[Type[BaseSettings]] property

Settings class for the Mlflow experiment tracker.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[StackValidator] property

Checks the stack has a LocalArtifactStore if no tracking uri was specified.

Returns:

Type Description
Optional[StackValidator]

An optional StackValidator.

Functions
cleanup_step_run(info: StepRunInfo, step_failed: bool) -> None

Stops active MLflow runs and resets the MLflow tracking uri.

Parameters:

Name Type Description Default
info StepRunInfo

Info about the step that was executed.

required
step_failed bool

Whether the step failed or not.

required
Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
def cleanup_step_run(
    self,
    info: "StepRunInfo",
    step_failed: bool,
) -> None:
    """Stops active MLflow runs and resets the MLflow tracking uri.

    Args:
        info: Info about the step that was executed.
        step_failed: Whether the step failed or not.
    """
    if step_failed:
        run_status = "FAILED"
    else:
        run_status = "FINISHED"
    self.disable_autologging()
    mlflow_utils.stop_zenml_mlflow_runs(run_status)
    # Clear the tracking URI so later code doesn't inherit this step's
    # MLflow configuration.
    mlflow.set_tracking_uri("")
configure_mlflow() -> None

Configures the MLflow tracking URI and any additional credentials.

Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def configure_mlflow(self) -> None:
    """Configures the MLflow tracking URI and any additional credentials."""
    tracking_uri = self.get_tracking_uri()
    mlflow.set_tracking_uri(tracking_uri)

    if is_databricks_tracking_uri(tracking_uri):
        # Databricks backends are configured through their own set of
        # environment variables; only set the ones that have values.
        for env_var, value in (
            (DATABRICKS_HOST, self.config.databricks_host),
            (DATABRICKS_USERNAME, self.config.tracking_username),
            (DATABRICKS_PASSWORD, self.config.tracking_password),
            (DATABRICKS_TOKEN, self.config.tracking_token),
        ):
            if value:
                os.environ[env_var] = value
        if self.config.enable_unity_catalog:
            mlflow.set_registry_uri(DATABRICKS_UNITY_CATALOG)
    else:
        os.environ[MLFLOW_TRACKING_URI] = tracking_uri
        for env_var, value in (
            (MLFLOW_TRACKING_USERNAME, self.config.tracking_username),
            (MLFLOW_TRACKING_PASSWORD, self.config.tracking_password),
            (MLFLOW_TRACKING_TOKEN, self.config.tracking_token),
        ):
            if value:
                os.environ[env_var] = value

    insecure_tls = self.config.tracking_insecure_tls
    os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
        "true" if insecure_tls else "false"
    )
disable_autologging() -> None

Disables MLflow autologging for all supported frameworks.

Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
def disable_autologging(self) -> None:
    """Disables MLflow autologging for all supported frameworks."""
    supported_frameworks = (
        "tensorflow",
        "xgboost",
        "lightgbm",
        "statsmodels",
        "spark",
        "sklearn",
        "fastai",
        "pytorch",
    )

    failed_frameworks = []
    for framework in supported_frameworks:
        try:
            # Each framework has an `mlflow.<framework>` submodule that
            # exposes an `autolog` entry point.
            framework_module = importlib.import_module(f"mlflow.{framework}")
            framework_module.autolog(disable=True)
        except ImportError as e:
            # only log on mlflow relevant errors
            if "mlflow" in e.msg.lower():
                failed_frameworks.append(framework)
        except Exception:
            failed_frameworks.append(framework)

    if failed_frameworks:
        logger.warning(
            f"Failed to disable MLflow autologging for the following frameworks: "
            f"{failed_frameworks}."
        )
get_run_id(experiment_name: str, run_name: str) -> Optional[str]

Gets the ID of a run with the given name and experiment.

Parameters:

Name Type Description Default
experiment_name str

Name of the experiment in which to search for the run.

required
run_name str

Name of the run to search.

required

Returns:

Type Description
Optional[str]

The id of the run if it exists.

Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]:
    """Gets the ID of a run with the given name and experiment.

    Args:
        experiment_name: Name of the experiment in which to search for the
            run.
        run_name: Name of the run to search.

    Returns:
        The id of the run if it exists.
    """
    self.configure_mlflow()
    experiment_name = self._adjust_experiment_name(experiment_name)

    # NOTE(review): run_view_type=3 presumably corresponds to
    # ViewType.ALL — confirm against the mlflow entities module.
    matching_runs = mlflow.search_runs(
        experiment_names=[experiment_name],
        filter_string=f'tags.mlflow.runName = "{run_name}"',
        run_view_type=3,
        output_format="list",
    )
    if not matching_runs:
        return None

    candidate: Run = matching_runs[0]
    # Only report runs that were started by ZenML.
    if not mlflow_utils.is_zenml_run(candidate):
        return None
    return cast(str, candidate.info.run_id)
get_step_run_metadata(info: StepRunInfo) -> Dict[str, MetadataType]

Get component- and step-specific metadata after a step ran.

Parameters:

Name Type Description Default
info StepRunInfo

Info about the step that was executed.

required

Returns:

Type Description
Dict[str, MetadataType]

A dictionary of metadata.

Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
def get_step_run_metadata(
    self, info: "StepRunInfo"
) -> Dict[str, "MetadataType"]:
    """Get component- and step-specific metadata after a step ran.

    Args:
        info: Info about the step that was executed.

    Returns:
        A dictionary of metadata.
    """
    metadata: Dict[str, Any] = {}
    metadata[METADATA_EXPERIMENT_TRACKER_URL] = Uri(
        self.get_tracking_uri(as_plain_text=False)
    )

    active_run = mlflow.active_run()
    if active_run:
        # Only available while an MLflow run is open for this step.
        metadata["mlflow_run_id"] = active_run.info.run_id
        metadata["mlflow_experiment_id"] = active_run.info.experiment_id

    return metadata
get_tracking_uri(as_plain_text: bool = True) -> str

Returns the configured tracking URI or a local fallback.

Parameters:

Name Type Description Default
as_plain_text bool

Whether to return the tracking URI as plain text.

True

Returns:

Type Description
str

The tracking URI.

Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def get_tracking_uri(self, as_plain_text: bool = True) -> str:
    """Returns the configured tracking URI or a local fallback.

    Args:
        as_plain_text: Whether to return the tracking URI as plain text.

    Returns:
        The tracking URI.
    """
    # NOTE(review): `model_dump()` presumably reveals secret-typed URI
    # values that plain attribute access would mask — confirm against
    # the config model.
    configured_uri = (
        self.config.tracking_uri
        if as_plain_text
        else self.config.model_dump()["tracking_uri"]
    )
    return configured_uri or self._local_mlflow_backend()
prepare_step_run(info: StepRunInfo) -> None

Sets the MLflow tracking uri and credentials.

Parameters:

Name Type Description Default
info StepRunInfo

Info about the step that will be executed.

required
Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
def prepare_step_run(self, info: "StepRunInfo") -> None:
    """Sets the MLflow tracking uri and credentials.

    Args:
        info: Info about the step that will be executed.
    """
    self.configure_mlflow()
    tracker_settings = cast(
        MLFlowExperimentTrackerSettings,
        self.get_settings(info),
    )

    experiment_name = (
        tracker_settings.experiment_name or info.pipeline.name
    )
    experiment = self._set_active_experiment(experiment_name)
    run_id = self.get_run_id(
        experiment_name=experiment_name, run_name=info.run_name
    )

    # Merge user-provided tags with the internal ZenML tags; internal
    # tags win on key collisions, matching the original update order.
    tags = {**tracker_settings.tags, **self._get_internal_tags()}

    mlflow.start_run(
        run_id=run_id,
        run_name=info.run_name,
        experiment_id=experiment.experiment_id,
        tags=tags,
    )

    if tracker_settings.nested:
        # Open a nested run scoped to the individual step.
        mlflow.start_run(
            run_name=info.pipeline_step_name, nested=True, tags=tags
        )
Modules
mlflow_experiment_tracker

Implementation of the MLflow experiment tracker for ZenML.

Classes
MLFlowExperimentTracker(*args: Any, **kwargs: Any)

Bases: BaseExperimentTracker

Track experiments using MLflow.

Initialize the experiment tracker and validate the tracking uri.

Parameters:

Name Type Description Default
*args Any

Variable length argument list.

()
**kwargs Any

Arbitrary keyword arguments.

{}
Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
66
67
68
69
70
71
72
73
74
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialize the experiment tracker and validate the tracking uri.

    Args:
        *args: Variable length argument list.
        **kwargs: Arbitrary keyword arguments.
    """
    # Let the base component set up its state first, then verify that
    # the configured tracking URI is well-formed.
    super().__init__(*args, **kwargs)
    self._ensure_valid_tracking_uri()
Attributes
config: MLFlowExperimentTrackerConfig property

Returns the MLFlowExperimentTrackerConfig config.

Returns:

Type Description
MLFlowExperimentTrackerConfig

The configuration.

local_path: Optional[str] property

Path to the local directory where the MLflow artifacts are stored.

Returns:

Type Description
Optional[str]

None if configured with a remote tracking URI, otherwise the

Optional[str]

path to the local MLflow artifact store directory.

settings_class: Optional[Type[BaseSettings]] property

Settings class for the Mlflow experiment tracker.

Returns:

Type Description
Optional[Type[BaseSettings]]

The settings class.

validator: Optional[StackValidator] property

Checks the stack has a LocalArtifactStore if no tracking uri was specified.

Returns:

Type Description
Optional[StackValidator]

An optional StackValidator.

Functions
cleanup_step_run(info: StepRunInfo, step_failed: bool) -> None

Stops active MLflow runs and resets the MLflow tracking uri.

Parameters:

Name Type Description Default
info StepRunInfo

Info about the step that was executed.

required
step_failed bool

Whether the step failed or not.

required
Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
def cleanup_step_run(
    self,
    info: "StepRunInfo",
    step_failed: bool,
) -> None:
    """Stops active MLflow runs and resets the MLflow tracking uri.

    Args:
        info: Info about the step that was executed.
        step_failed: Whether the step failed or not.
    """
    if step_failed:
        run_status = "FAILED"
    else:
        run_status = "FINISHED"
    self.disable_autologging()
    mlflow_utils.stop_zenml_mlflow_runs(run_status)
    # Clear the tracking URI so later code doesn't inherit this step's
    # MLflow configuration.
    mlflow.set_tracking_uri("")
configure_mlflow() -> None

Configures the MLflow tracking URI and any additional credentials.

Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def configure_mlflow(self) -> None:
    """Configures the MLflow tracking URI and any additional credentials."""
    tracking_uri = self.get_tracking_uri()
    mlflow.set_tracking_uri(tracking_uri)

    cfg = self.config
    if is_databricks_tracking_uri(tracking_uri):
        # Databricks-managed tracking server: credentials are passed via
        # the Databricks-specific environment variables.
        for env_var, env_value in (
            (DATABRICKS_HOST, cfg.databricks_host),
            (DATABRICKS_USERNAME, cfg.tracking_username),
            (DATABRICKS_PASSWORD, cfg.tracking_password),
            (DATABRICKS_TOKEN, cfg.tracking_token),
        ):
            if env_value:
                os.environ[env_var] = env_value
        if cfg.enable_unity_catalog:
            mlflow.set_registry_uri(DATABRICKS_UNITY_CATALOG)
    else:
        os.environ[MLFLOW_TRACKING_URI] = tracking_uri
        for env_var, env_value in (
            (MLFLOW_TRACKING_USERNAME, cfg.tracking_username),
            (MLFLOW_TRACKING_PASSWORD, cfg.tracking_password),
            (MLFLOW_TRACKING_TOKEN, cfg.tracking_token),
        ):
            if env_value:
                os.environ[env_var] = env_value

    os.environ[MLFLOW_TRACKING_INSECURE_TLS] = (
        "true" if cfg.tracking_insecure_tls else "false"
    )
disable_autologging() -> None

Disables MLflow autologging for all supported frameworks.

Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
def disable_autologging(self) -> None:
    """Disables MLflow autologging for all supported frameworks.

    Iterates over the frameworks for which MLflow ships an autologging
    integration and calls ``mlflow.<framework>.autolog(disable=True)``
    for each. Frameworks whose autologging could not be disabled are
    collected and reported in a single warning.
    """
    frameworks = [
        "tensorflow",
        "xgboost",
        "lightgbm",
        "statsmodels",
        "spark",
        "sklearn",
        "fastai",
        "pytorch",
    ]

    failed_frameworks = []

    for framework in frameworks:
        try:
            # The autologging integrations live in `mlflow.<framework>`
            # submodules; import one dynamically per framework.
            module_name = f"mlflow.{framework}"
            module = importlib.import_module(module_name)
            # Call the autolog function with disable=True
            module.autolog(disable=True)
        except ImportError as e:
            # Only log on mlflow-relevant errors. Use `str(e)` rather
            # than `e.msg`: `msg` is `None` for ImportErrors raised
            # without a message, and `None.lower()` would raise an
            # AttributeError out of this handler.
            if "mlflow" in str(e).lower():
                failed_frameworks.append(framework)
        except Exception:
            failed_frameworks.append(framework)

    if failed_frameworks:
        logger.warning(
            f"Failed to disable MLflow autologging for the following frameworks: "
            f"{failed_frameworks}."
        )
get_run_id(experiment_name: str, run_name: str) -> Optional[str]

Gets the id of a run with the given name and experiment.

Parameters:

Name Type Description Default
experiment_name str

Name of the experiment in which to search for the run.

required
run_name str

Name of the run to search.

required

Returns:

Type Description
Optional[str]

The id of the run if it exists.

Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]:
    """Gets the id of a run with the given name and experiment.

    Args:
        experiment_name: Name of the experiment in which to search for the
            run.
        run_name: Name of the run to search.

    Returns:
        The id of the run if it exists and was created by ZenML,
        otherwise `None`.
    """
    # Ensure the tracking URI and credentials are set before talking to
    # the MLflow backend.
    self.configure_mlflow()
    experiment_name = self._adjust_experiment_name(experiment_name)

    runs = mlflow.search_runs(
        experiment_names=[experiment_name],
        filter_string=f'tags.mlflow.runName = "{run_name}"',
        # 3 presumably corresponds to ViewType.ALL (active + deleted
        # runs) — NOTE(review): confirm against mlflow.entities.ViewType.
        run_view_type=3,
        output_format="list",
    )
    if not runs:
        return None

    run: Run = runs[0]
    # Only return runs that ZenML itself created; ignore foreign runs
    # that happen to share the same name.
    if mlflow_utils.is_zenml_run(run):
        return cast(str, run.info.run_id)
    else:
        return None
get_step_run_metadata(info: StepRunInfo) -> Dict[str, MetadataType]

Get component- and step-specific metadata after a step ran.

Parameters:

Name Type Description Default
info StepRunInfo

Info about the step that was executed.

required

Returns:

Type Description
Dict[str, MetadataType]

A dictionary of metadata.

Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
def get_step_run_metadata(
    self, info: "StepRunInfo"
) -> Dict[str, "MetadataType"]:
    """Get component- and step-specific metadata after a step ran.

    Args:
        info: Info about the step that was executed.

    Returns:
        A dictionary of metadata.
    """
    metadata: Dict[str, Any] = {
        METADATA_EXPERIMENT_TRACKER_URL: Uri(
            self.get_tracking_uri(as_plain_text=False)
        ),
    }
    # Record the run/experiment ids only while an MLflow run is active.
    active_run = mlflow.active_run()
    if active_run:
        metadata["mlflow_run_id"] = active_run.info.run_id
        metadata["mlflow_experiment_id"] = active_run.info.experiment_id

    return metadata
get_tracking_uri(as_plain_text: bool = True) -> str

Returns the configured tracking URI or a local fallback.

Parameters:

Name Type Description Default
as_plain_text bool

Whether to return the tracking URI as plain text.

True

Returns:

Type Description
str

The tracking URI.

Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def get_tracking_uri(self, as_plain_text: bool = True) -> str:
    """Returns the configured tracking URI or a local fallback.

    Args:
        as_plain_text: Whether to return the tracking URI as plain text.

    Returns:
        The tracking URI.
    """
    configured_uri = (
        self.config.tracking_uri
        if as_plain_text
        else self.config.model_dump()["tracking_uri"]
    )
    # Fall back to a local MLflow backend when no URI is configured.
    return configured_uri or self._local_mlflow_backend()
prepare_step_run(info: StepRunInfo) -> None

Sets the MLflow tracking uri and credentials.

Parameters:

Name Type Description Default
info StepRunInfo

Info about the step that will be executed.

required
Source code in src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
def prepare_step_run(self, info: "StepRunInfo") -> None:
    """Sets the MLflow tracking uri and credentials.

    Args:
        info: Info about the step that will be executed.
    """
    self.configure_mlflow()
    settings = cast(
        MLFlowExperimentTrackerSettings,
        self.get_settings(info),
    )

    # An explicit experiment name in the settings wins over the default
    # of using the pipeline name.
    experiment_name = settings.experiment_name or info.pipeline.name
    experiment = self._set_active_experiment(experiment_name)
    run_id = self.get_run_id(
        experiment_name=experiment_name, run_name=info.run_name
    )

    # Merge user-provided tags with ZenML's internal tags.
    tags = {**settings.tags, **self._get_internal_tags()}

    mlflow.start_run(
        run_id=run_id,
        run_name=info.run_name,
        experiment_id=experiment.experiment_id,
        tags=tags,
    )

    if settings.nested:
        mlflow.start_run(
            run_name=info.pipeline_step_name, nested=True, tags=tags
        )
Functions Modules

flavors

MLFlow integration flavors.

Classes
MLFlowExperimentTrackerConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseExperimentTrackerConfig, MLFlowExperimentTrackerSettings

Config for the MLflow experiment tracker.

Attributes:

Name Type Description
tracking_uri Optional[str]

The uri of the mlflow tracking server. If no uri is set, your stack must contain a LocalArtifactStore and ZenML will point MLflow to a subdirectory of your artifact store instead.

tracking_username Optional[str]

Username for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_password Optional[str]

Password for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_token Optional[str]

Token for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_insecure_tls bool

Skips verification of TLS connection to the MLflow tracking server if set to True.

databricks_host Optional[str]

The host of the Databricks workspace with the MLflow managed server to connect to. This is only required if tracking_uri value is set to "databricks".

enable_unity_catalog bool

If True, will enable the Databricks Unity Catalog for logging and registering models.

Source code in src/zenml/stack/stack_component.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    Stack components accept secret references of the form
    `{{secret_name.key}}` for their string attributes. This only works
    for attributes without custom pydantic validators: a validator would
    run on the reference string itself and could fail or, worse, mangle
    the reference. This constructor therefore rejects secret references
    for any attribute that requires custom validation.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    model_fields = self.__class__.model_fields
    for name, value in kwargs.items():
        if name not in model_fields:
            # Private attribute or unknown field; the pydantic
            # validation triggered below will report it.
            continue
        if value is None:
            continue

        field = model_fields[name]
        if not secret_utils.is_secret_reference(value):
            if (
                warn_about_plain_text_secrets
                and secret_utils.is_secret_field(field)
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{name}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=name
        ):
            raise ValueError(
                f"Passing the stack component attribute `{name}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
Attributes
is_local: bool property

Checks if this stack component is running locally.

Returns:

Type Description
bool

True if this config is for a local component, False otherwise.

MLFlowExperimentTrackerFlavor

Bases: BaseExperimentTrackerFlavor

Class for the MLFlowExperimentTrackerFlavor.

Attributes
config_class: Type[MLFlowExperimentTrackerConfig] property

Returns MLFlowExperimentTrackerConfig config class.

Returns:

Type Description
Type[MLFlowExperimentTrackerConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowExperimentTracker] property

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowExperimentTracker]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

MLFlowModelDeployerConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseModelDeployerConfig

Configuration for the MLflow model deployer.

Attributes:

Name Type Description
service_path str

the path where the local MLflow deployment service configuration, PID and log files are stored.

Source code in src/zenml/stack/stack_component.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    Stack components accept secret references of the form
    `{{secret_name.key}}` for their string attributes. This only works
    for attributes without custom pydantic validators: a validator would
    run on the reference string itself and could fail or, worse, mangle
    the reference. This constructor therefore rejects secret references
    for any attribute that requires custom validation.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    model_fields = self.__class__.model_fields
    for name, value in kwargs.items():
        if name not in model_fields:
            # Private attribute or unknown field; the pydantic
            # validation triggered below will report it.
            continue
        if value is None:
            continue

        field = model_fields[name]
        if not secret_utils.is_secret_reference(value):
            if (
                warn_about_plain_text_secrets
                and secret_utils.is_secret_field(field)
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{name}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=name
        ):
            raise ValueError(
                f"Passing the stack component attribute `{name}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
Attributes
is_local: bool property

Checks if this stack component is running locally.

Returns:

Type Description
bool

True if this config is for a local component, False otherwise.

MLFlowModelDeployerFlavor

Bases: BaseModelDeployerFlavor

Model deployer flavor for MLflow models.

Attributes
config_class: Type[MLFlowModelDeployerConfig] property

Returns MLFlowModelDeployerConfig config class.

Returns:

Type Description
Type[MLFlowModelDeployerConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowModelDeployer] property

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowModelDeployer]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

MLFlowModelRegistryConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseModelRegistryConfig

Configuration for the MLflow model registry.

Source code in src/zenml/stack/stack_component.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    Stack components accept secret references of the form
    `{{secret_name.key}}` for their string attributes. This only works
    for attributes without custom pydantic validators: a validator would
    run on the reference string itself and could fail or, worse, mangle
    the reference. This constructor therefore rejects secret references
    for any attribute that requires custom validation.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    model_fields = self.__class__.model_fields
    for name, value in kwargs.items():
        if name not in model_fields:
            # Private attribute or unknown field; the pydantic
            # validation triggered below will report it.
            continue
        if value is None:
            continue

        field = model_fields[name]
        if not secret_utils.is_secret_reference(value):
            if (
                warn_about_plain_text_secrets
                and secret_utils.is_secret_field(field)
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{name}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=name
        ):
            raise ValueError(
                f"Passing the stack component attribute `{name}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
MLFlowModelRegistryFlavor

Bases: BaseModelRegistryFlavor

Model registry flavor for MLflow models.

Attributes
config_class: Type[MLFlowModelRegistryConfig] property

Returns MLFlowModelRegistryConfig config class.

Returns:

Type Description
Type[MLFlowModelRegistryConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowModelRegistry] property

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowModelRegistry]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

Modules
mlflow_experiment_tracker_flavor

MLflow experiment tracker flavor.

Classes
MLFlowExperimentTrackerConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseExperimentTrackerConfig, MLFlowExperimentTrackerSettings

Config for the MLflow experiment tracker.

Attributes:

Name Type Description
tracking_uri Optional[str]

The uri of the mlflow tracking server. If no uri is set, your stack must contain a LocalArtifactStore and ZenML will point MLflow to a subdirectory of your artifact store instead.

tracking_username Optional[str]

Username for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_password Optional[str]

Password for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_token Optional[str]

Token for authenticating with the MLflow tracking server. When a remote tracking uri is specified, either tracking_token or tracking_username and tracking_password must be specified.

tracking_insecure_tls bool

Skips verification of TLS connection to the MLflow tracking server if set to True.

databricks_host Optional[str]

The host of the Databricks workspace with the MLflow managed server to connect to. This is only required if tracking_uri value is set to "databricks".

enable_unity_catalog bool

If True, will enable the Databricks Unity Catalog for logging and registering models.

Source code in src/zenml/stack/stack_component.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    Stack components accept secret references of the form
    `{{secret_name.key}}` for their string attributes. This only works
    for attributes without custom pydantic validators: a validator would
    run on the reference string itself and could fail or, worse, mangle
    the reference. This constructor therefore rejects secret references
    for any attribute that requires custom validation.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            is passed as a secret reference, or if the `name` attribute
            was passed as a secret reference.
    """
    model_fields = self.__class__.model_fields
    for name, value in kwargs.items():
        if name not in model_fields:
            # Private attribute or unknown field; the pydantic
            # validation triggered below will report it.
            continue
        if value is None:
            continue

        field = model_fields[name]
        if not secret_utils.is_secret_reference(value):
            if (
                warn_about_plain_text_secrets
                and secret_utils.is_secret_field(field)
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{name}` for a `{self.__class__.__name__}` "
                    "stack component. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure your stack "
                    "components with secrets here: "
                    "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if pydantic_utils.has_validators(
            pydantic_class=self.__class__, field_name=name
        ):
            raise ValueError(
                f"Passing the stack component attribute `{name}` as a "
                "secret reference is not allowed as additional validation "
                "is required for this attribute."
            )

    super().__init__(**kwargs)
Attributes
is_local: bool property

Checks if this stack component is running locally.

Returns:

Type Description
bool

True if this config is for a local component, False otherwise.

MLFlowExperimentTrackerFlavor

Bases: BaseExperimentTrackerFlavor

Class for the MLFlowExperimentTrackerFlavor.

Attributes
config_class: Type[MLFlowExperimentTrackerConfig] property

Returns MLFlowExperimentTrackerConfig config class.

Returns:

Type Description
Type[MLFlowExperimentTrackerConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowExperimentTracker] property

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowExperimentTracker]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

MLFlowExperimentTrackerSettings(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseSettings

Settings for the MLflow experiment tracker.

Attributes:

Name Type Description
experiment_name Optional[str]

The MLflow experiment name.

nested bool

If True, will create a nested sub-run for the step.

tags Dict[str, Any]

Tags for the Mlflow run.

Source code in src/zenml/config/secret_reference_mixin.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references are only passed for valid fields.

    Secret references are rejected for fields that explicitly disallow
    them, and for fields with custom pydantic validation (the validator
    would run on the reference string itself).

    Args:
        warn_about_plain_text_secrets: If true, then warns about using plain-text secrets.
        **kwargs: Arguments to initialize this object.

    Raises:
        ValueError: If an attribute that requires custom pydantic validation
            or an attribute which explicitly disallows secret references
            is passed as a secret reference.
    """
    model_fields = self.__class__.model_fields
    for name, value in kwargs.items():
        if name not in model_fields:
            # Private attribute or unknown field; the pydantic
            # validation triggered below will report it.
            continue
        if value is None:
            continue

        field = model_fields[name]
        if not secret_utils.is_secret_reference(value):
            if (
                warn_about_plain_text_secrets
                and secret_utils.is_secret_field(field)
            ):
                logger.warning(
                    "You specified a plain-text value for the sensitive "
                    f"attribute `{name}`. This is currently only a warning, "
                    "but future versions of ZenML will require you to pass "
                    "in sensitive information as secrets. Check out the "
                    "documentation on how to configure values with secrets "
                    "here: https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
                )
            continue

        if secret_utils.is_clear_text_field(field):
            raise ValueError(
                f"Passing the `{name}` attribute as a secret reference is "
                "not allowed."
            )

        if has_validators(pydantic_class=self.__class__, field_name=name):
            raise ValueError(
                f"Passing the attribute `{name}` as a secret reference is "
                "not allowed as additional validation is required for "
                "this attribute."
            )

    super().__init__(**kwargs)
Functions
is_databricks_tracking_uri(tracking_uri: str) -> bool

Checks whether the given tracking uri is a Databricks tracking uri.

Parameters:

Name Type Description Default
tracking_uri str

The tracking uri to check.

required

Returns:

Type Description
bool

True if the tracking uri is a Databricks tracking uri, False

bool

otherwise.

Source code in src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
48
49
50
51
52
53
54
55
56
57
58
def is_databricks_tracking_uri(tracking_uri: str) -> bool:
    """Checks whether the given tracking uri is a Databricks tracking uri.

    MLflow treats the literal uri "databricks" as a request to use the
    Databricks-managed tracking server.

    Args:
        tracking_uri: The tracking uri to check.

    Returns:
        `True` if the tracking uri is a Databricks tracking uri, `False`
        otherwise.
    """
    return tracking_uri == "databricks"
is_remote_mlflow_tracking_uri(tracking_uri: str) -> bool

Checks whether the given tracking uri is remote or not.

Parameters:

Name Type Description Default
tracking_uri str

The tracking uri to check.

required

Returns:

Type Description
bool

True if the tracking uri is remote, False otherwise.

Source code in src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py
34
35
36
37
38
39
40
41
42
43
44
45
def is_remote_mlflow_tracking_uri(tracking_uri: str) -> bool:
    """Determine whether a tracking URI refers to a remote MLflow server.

    A URI counts as remote when it uses an HTTP(S) scheme or when it is
    the Databricks tracking URI.

    Args:
        tracking_uri: The tracking URI to inspect.

    Returns:
        `True` if the tracking URI is remote, `False` otherwise.
    """
    if tracking_uri.startswith(("http://", "https://")):
        return True
    return is_databricks_tracking_uri(tracking_uri)
mlflow_model_deployer_flavor

MLflow model deployer flavor.

Classes
MLFlowModelDeployerConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseModelDeployerConfig

Configuration for the MLflow model deployer.

Attributes:

Name Type Description
service_path str

the path where the local MLflow deployment service configuration, PID and log files are stored.

Source code in src/zenml/stack/stack_component.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    Stack components may specify any of their string attributes as secret
    references of the form `{{secret_name.key}}`. That only works for
    attributes without explicit pydantic validators: a validator would run
    on the raw reference string and either fail or, worse, mangle the
    reference. This constructor therefore rejects secret references for
    any attribute that has custom validation, and optionally warns when a
    sensitive attribute is given as plain text.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic
            validation is passed as a secret reference.
    """
    for attr_name, attr_value in kwargs.items():
        # Unknown keys belong to private attributes or non-existing
        # fields; the upcoming pydantic validation handles those, and
        # `None` values cannot be secret references anyway.
        field = self.__class__.model_fields.get(attr_name)
        if field is None or attr_value is None:
            continue

        if secret_utils.is_secret_reference(attr_value):
            # A custom validator would operate on the raw reference
            # string, so the combination is disallowed.
            if pydantic_utils.has_validators(
                pydantic_class=self.__class__, field_name=attr_name
            ):
                raise ValueError(
                    f"Passing the stack component attribute `{attr_name}` as a "
                    "secret reference is not allowed as additional validation "
                    "is required for this attribute."
                )
        elif (
            secret_utils.is_secret_field(field)
            and warn_about_plain_text_secrets
        ):
            logger.warning(
                "You specified a plain-text value for the sensitive "
                f"attribute `{attr_name}` for a `{self.__class__.__name__}` "
                "stack component. This is currently only a warning, "
                "but future versions of ZenML will require you to pass "
                "in sensitive information as secrets. Check out the "
                "documentation on how to configure your stack "
                "components with secrets here: "
                "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
            )

    super().__init__(**kwargs)
Attributes
is_local: bool property

Checks if this stack component is running locally.

Returns:

Type Description
bool

True if this config is for a local component, False otherwise.

MLFlowModelDeployerFlavor

Bases: BaseModelDeployerFlavor

Model deployer flavor for MLflow models.

Attributes
config_class: Type[MLFlowModelDeployerConfig] property

Returns MLFlowModelDeployerConfig config class.

Returns:

Type Description
Type[MLFlowModelDeployerConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowModelDeployer] property

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowModelDeployer]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

mlflow_model_registry_flavor

MLflow model registry flavor.

Classes
MLFlowModelRegistryConfig(warn_about_plain_text_secrets: bool = False, **kwargs: Any)

Bases: BaseModelRegistryConfig

Configuration for the MLflow model registry.

Source code in src/zenml/stack/stack_component.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self, warn_about_plain_text_secrets: bool = False, **kwargs: Any
) -> None:
    """Ensures that secret references don't clash with pydantic validation.

    Stack components may specify any of their string attributes as secret
    references of the form `{{secret_name.key}}`. That only works for
    attributes without explicit pydantic validators: a validator would run
    on the raw reference string and either fail or, worse, mangle the
    reference. This constructor therefore rejects secret references for
    any attribute that has custom validation, and optionally warns when a
    sensitive attribute is given as plain text.

    Args:
        warn_about_plain_text_secrets: If true, then warns about using
            plain-text secrets.
        **kwargs: Arguments to initialize this stack component.

    Raises:
        ValueError: If an attribute that requires custom pydantic
            validation is passed as a secret reference.
    """
    for attr_name, attr_value in kwargs.items():
        # Unknown keys belong to private attributes or non-existing
        # fields; the upcoming pydantic validation handles those, and
        # `None` values cannot be secret references anyway.
        field = self.__class__.model_fields.get(attr_name)
        if field is None or attr_value is None:
            continue

        if secret_utils.is_secret_reference(attr_value):
            # A custom validator would operate on the raw reference
            # string, so the combination is disallowed.
            if pydantic_utils.has_validators(
                pydantic_class=self.__class__, field_name=attr_name
            ):
                raise ValueError(
                    f"Passing the stack component attribute `{attr_name}` as a "
                    "secret reference is not allowed as additional validation "
                    "is required for this attribute."
                )
        elif (
            secret_utils.is_secret_field(field)
            and warn_about_plain_text_secrets
        ):
            logger.warning(
                "You specified a plain-text value for the sensitive "
                f"attribute `{attr_name}` for a `{self.__class__.__name__}` "
                "stack component. This is currently only a warning, "
                "but future versions of ZenML will require you to pass "
                "in sensitive information as secrets. Check out the "
                "documentation on how to configure your stack "
                "components with secrets here: "
                "https://docs.zenml.io/getting-started/deploying-zenml/secret-management"
            )

    super().__init__(**kwargs)
MLFlowModelRegistryFlavor

Bases: BaseModelRegistryFlavor

Model registry flavor for MLflow models.

Attributes
config_class: Type[MLFlowModelRegistryConfig] property

Returns MLFlowModelRegistryConfig config class.

Returns:

Type Description
Type[MLFlowModelRegistryConfig]

The config class.

docs_url: Optional[str] property

A url to point at docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor docs url.

implementation_class: Type[MLFlowModelRegistry] property

Implementation class for this flavor.

Returns:

Type Description
Type[MLFlowModelRegistry]

The implementation class.

logo_url: str property

A url to represent the flavor in the dashboard.

Returns:

Type Description
str

The flavor logo.

name: str property

Name of the flavor.

Returns:

Type Description
str

The name of the flavor.

sdk_docs_url: Optional[str] property

A url to point at SDK docs explaining this flavor.

Returns:

Type Description
Optional[str]

A flavor SDK docs url.

mlflow_utils

Implementation of utils specific to the MLflow integration.

Classes
Functions
get_missing_mlflow_experiment_tracker_error() -> ValueError

Returns description of how to add an MLflow experiment tracker to your stack.

Returns:

Name Type Description
ValueError ValueError

If no MLflow experiment tracker is registered in the active stack.

Source code in src/zenml/integrations/mlflow/mlflow_utils.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def get_missing_mlflow_experiment_tracker_error() -> ValueError:
    """Build the error raised when no MLflow experiment tracker is present.

    Returns:
        ValueError: Describes how to register an MLflow experiment tracker
            component in the active stack.
    """
    message = (
        "The active stack needs to have a MLflow experiment tracker "
        "component registered to be able to track experiments using "
        "MLflow. You can create a new stack with a MLflow experiment "
        "tracker component or update your existing stack to add this "
        "component, e.g.:\n\n"
        "  'zenml experiment-tracker register mlflow_tracker "
        "--type=mlflow'\n"
        "  'zenml stack register stack-name -e mlflow_tracker ...'\n"
    )
    return ValueError(message)
get_tracking_uri() -> str

Gets the MLflow tracking URI from the active experiment tracking stack component.

noqa: DAR401

Returns:

Type Description
str

MLflow tracking URI.

Source code in src/zenml/integrations/mlflow/mlflow_utils.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def get_tracking_uri() -> str:
    """Gets the MLflow tracking URI from the active experiment tracker.

    # noqa: DAR401

    Returns:
        MLflow tracking URI.
    """
    # Imported lazily to avoid importing MLflow machinery at module load.
    from zenml.integrations.mlflow.experiment_trackers.mlflow_experiment_tracker import (
        MLFlowExperimentTracker,
    )

    experiment_tracker = Client().active_stack.experiment_tracker
    # isinstance() is False for None, so this also covers a stack with no
    # experiment tracker registered at all.
    if not isinstance(experiment_tracker, MLFlowExperimentTracker):
        raise get_missing_mlflow_experiment_tracker_error()

    return experiment_tracker.get_tracking_uri()
is_zenml_run(run: Run) -> bool

Checks if a MLflow run is a ZenML run or not.

Parameters:

Name Type Description Default
run Run

The run to check.

required

Returns:

Type Description
bool

If the run is a ZenML run.

Source code in src/zenml/integrations/mlflow/mlflow_utils.py
64
65
66
67
68
69
70
71
72
73
def is_zenml_run(run: Run) -> bool:
    """Determine whether an MLflow run was produced by ZenML.

    Args:
        run: The MLflow run to inspect.

    Returns:
        `True` if the run carries the ZenML tag, `False` otherwise.
    """
    run_tags = run.data.tags
    return ZENML_TAG_KEY in run_tags
stop_zenml_mlflow_runs(status: str) -> None

Stops active ZenML Mlflow runs.

This function stops all MLflow active runs until no active run exists or a non-ZenML run is active.

Parameters:

Name Type Description Default
status str

The status to set the run to.

required
Source code in src/zenml/integrations/mlflow/mlflow_utils.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
def stop_zenml_mlflow_runs(status: str) -> None:
    """Stops active ZenML MLflow runs.

    Ends active MLflow runs one by one until no run is active anymore or
    the innermost active run was not started by ZenML.

    Args:
        status: The status to set the stopped runs to.
    """
    while active_run := mlflow.active_run():
        if not is_zenml_run(active_run):
            # Leave non-ZenML runs untouched; they belong to the user.
            break
        logger.debug("Stopping mlflow run %s.", active_run.info.run_id)
        mlflow.end_run(status=status)

model_deployers

Initialization of the MLflow model deployers.

Classes
MLFlowModelDeployer(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: BaseModelDeployer

MLflow implementation of the BaseModelDeployer.

Source code in src/zenml/stack/stack_component.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def __init__(
    self,
    name: str,
    id: UUID,
    config: StackComponentConfig,
    flavor: str,
    type: StackComponentType,
    user: Optional[UUID],
    created: datetime,
    updated: datetime,
    labels: Optional[Dict[str, Any]] = None,
    connector_requirements: Optional[ServiceConnectorRequirements] = None,
    connector: Optional[UUID] = None,
    connector_resource_id: Optional[str] = None,
    *args: Any,
    **kwargs: Any,
):
    """Initializes a StackComponent.

    Args:
        name: The name of the component.
        id: The unique ID of the component.
        config: The config of the component.
        flavor: The flavor of the component.
        type: The type of the component.
        user: The ID of the user who created the component.
        created: The creation time of the component.
        updated: The last update time of the component.
        labels: The labels of the component.
        connector_requirements: The requirements for the connector.
        connector: The ID of a connector linked to the component.
        connector_resource_id: The custom resource ID to access through
            the connector.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If a secret reference is passed as name.
    """
    # The component name is used to look the component up, so it must
    # never itself be a secret reference.
    if secret_utils.is_secret_reference(name):
        raise ValueError(
            "Passing the `name` attribute of a stack component as a "
            "secret reference is not allowed."
        )

    # Plain attribute assignments; the validated config object is kept
    # private and exposed via the `config` property on subclasses.
    self.id = id
    self.name = name
    self._config = config
    self.flavor = flavor
    self.type = type
    self.user = user
    self.created = created
    self.updated = updated
    self.labels = labels
    self.connector_requirements = connector_requirements
    self.connector = connector
    self.connector_resource_id = connector_resource_id
    # Lazily-populated service connector client cache.
    self._connector_instance: Optional[ServiceConnector] = None
Attributes
config: MLFlowModelDeployerConfig property

Returns the MLFlowModelDeployerConfig config.

Returns:

Type Description
MLFlowModelDeployerConfig

The configuration.

local_path: str property

Returns the path to the root directory.

This is where all configurations for MLflow deployment daemon processes are stored.

If the service path is not set in the config by the user, the path is set to a local default path according to the component ID.

Returns:

Type Description
str

The path to the local service root directory.

Functions
get_model_server_info(service_instance: MLFlowDeploymentService) -> Dict[str, Optional[str]] staticmethod

Return implementation specific information relevant to the user.

Parameters:

Name Type Description Default
service_instance MLFlowDeploymentService

Instance of an MLFlowDeploymentService

required

Returns:

Type Description
Dict[str, Optional[str]]

A dictionary containing the information.

Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
    """Return implementation specific information relevant to the user.

    Args:
        service_instance: Instance of a SeldonDeploymentService

    Returns:
        A dictionary containing the information.
    """
    return {
        "PREDICTION_URL": service_instance.endpoint.prediction_url,
        "MODEL_URI": service_instance.config.model_uri,
        "MODEL_NAME": service_instance.config.model_name,
        "REGISTRY_MODEL_NAME": service_instance.config.registry_model_name,
        "REGISTRY_MODEL_VERSION": service_instance.config.registry_model_version,
        "SERVICE_PATH": service_instance.status.runtime_path,
        "DAEMON_PID": str(service_instance.status.pid),
        "HEALTH_CHECK_URL": service_instance.endpoint.monitor.get_healthcheck_uri(
            service_instance.endpoint
        ),
    }
get_service_path(id_: UUID) -> str staticmethod

Get the path where local MLflow service information is stored.

This includes the deployment service configuration, PID and log files.

Parameters:

Name Type Description Default
id_ UUID

The ID of the MLflow model deployer.

required

Returns:

Type Description
str

The service path.

Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@staticmethod
def get_service_path(id_: UUID) -> str:
    """Get the path where local MLflow service information is stored.

    This covers the deployment service configuration, PID and log files.

    Args:
        id_: The ID of the MLflow model deployer.

    Returns:
        The service path.
    """
    local_stores_root = GlobalConfiguration().local_stores_path
    service_path = os.path.join(local_stores_root, str(id_))
    # Make sure the directory exists before anything is written into it.
    create_dir_recursive_if_not_exists(service_path)
    return service_path
perform_delete_model(service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False) -> None

Method to delete all configuration of a model server.

Parameters:

Name Type Description Default
service BaseService

The service to delete.

required
timeout int

Timeout in seconds to wait for the service to stop.

DEFAULT_SERVICE_START_STOP_TIMEOUT
force bool

If True, force the service to stop.

False
Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
def perform_delete_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Method to delete all configuration of a model server.

    Args:
        service: The service to delete.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    # Narrow the generic service type and delegate the actual teardown.
    mlflow_service = cast(MLFlowDeploymentService, service)
    self._clean_up_existing_service(
        existing_service=mlflow_service, timeout=timeout, force=force
    )
perform_deploy_model(id: UUID, config: ServiceConfig, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT) -> BaseService

Create a new MLflow deployment service or update an existing one.

This should serve the supplied model and deployment configuration.

This method has two modes of operation, depending on the replace argument value:

  • if replace is False, calling this method will create a new MLflow deployment server to reflect the model and other configuration parameters specified in the supplied MLflow service config.

  • if replace is True, this method will first attempt to find an existing MLflow deployment service that is equivalent to the supplied configuration parameters. Two or more MLflow deployment services are considered equivalent if they have the same pipeline_name, pipeline_step_name and model_name configuration parameters. To put it differently, two MLflow deployment services are equivalent if they serve versions of the same model deployed by the same pipeline step. If an equivalent MLflow deployment is found, it will be updated in place to reflect the new configuration parameters.

Callers should set replace to True if they want a continuous model deployment workflow that doesn't spin up a new MLflow deployment server for each new model version. If multiple equivalent MLflow deployment servers are found, one is selected at random to be updated and the others are deleted.

Parameters:

Name Type Description Default
id UUID

the ID of the MLflow deployment service to be created or updated.

required
config ServiceConfig

the configuration of the model to be deployed with MLflow.

required
timeout int

the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow server is provisioned, without waiting for it to fully start.

DEFAULT_SERVICE_START_STOP_TIMEOUT

Returns:

Type Description
BaseService

The ZenML MLflow deployment service object that can be used to

BaseService

interact with the MLflow model server.

Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def perform_deploy_model(
    self,
    id: UUID,
    config: ServiceConfig,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
    """Create a new MLflow deployment service.

    Provisions a fresh MLflow deployment server that serves the model
    described by the supplied service configuration.

    NOTE(review): the previous docstring described a `replace` argument
    and update-in-place behavior, but this method accepts no such
    argument and unconditionally creates a new service; the
    documentation has been corrected to match the implementation.

    Args:
        id: the ID of the MLflow deployment service to be created.
        config: the configuration of the model to be deployed with MLflow.
        timeout: the timeout in seconds to wait for the MLflow server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the MLflow
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML MLflow deployment service object that can be used to
        interact with the MLflow model server.
    """
    # Narrow the generic ServiceConfig to the MLflow-specific config
    # expected by the service factory.
    config = cast(MLFlowDeploymentConfig, config)
    service = self._create_new_service(
        id=id, timeout=timeout, config=config
    )
    logger.info(f"Created a new MLflow deployment service: {service}")
    return service
perform_start_model(service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT) -> BaseService

Method to start a model server.

Parameters:

Name Type Description Default
service BaseService

The service to start.

required
timeout int

Timeout in seconds to wait for the service to start.

DEFAULT_SERVICE_START_STOP_TIMEOUT

Returns:

Type Description
BaseService

The service that was started.

Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def perform_start_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
    """Method to start a model server.

    Args:
        service: The service to start.
        timeout: Timeout in seconds to wait for the service to start.

    Returns:
        The service that was started.
    """
    # Delegates to the service object; blocks for up to `timeout`
    # seconds while waiting for the server to come up.
    service.start(timeout=timeout)
    return service
perform_stop_model(service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False) -> BaseService

Method to stop a model server.

Parameters:

Name Type Description Default
service BaseService

The service to stop.

required
timeout int

Timeout in seconds to wait for the service to stop.

DEFAULT_SERVICE_START_STOP_TIMEOUT
force bool

If True, force the service to stop.

False

Returns:

Type Description
BaseService

The service that was stopped.

Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def perform_stop_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> BaseService:
    """Method to stop a model server.

    Args:
        service: The service to stop.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.

    Returns:
        The service that was stopped.
    """
    # Delegates to the service object; `force` is passed through so the
    # service may terminate its daemon process without a graceful stop.
    service.stop(timeout=timeout, force=force)
    return service
Modules
mlflow_model_deployer

Implementation of the MLflow model deployer.

Classes
MLFlowModelDeployer(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: BaseModelDeployer

MLflow implementation of the BaseModelDeployer.

Source code in src/zenml/stack/stack_component.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def __init__(
    self,
    name: str,
    id: UUID,
    config: StackComponentConfig,
    flavor: str,
    type: StackComponentType,
    user: Optional[UUID],
    created: datetime,
    updated: datetime,
    labels: Optional[Dict[str, Any]] = None,
    connector_requirements: Optional[ServiceConnectorRequirements] = None,
    connector: Optional[UUID] = None,
    connector_resource_id: Optional[str] = None,
    *args: Any,
    **kwargs: Any,
):
    """Initializes a StackComponent.

    Args:
        name: The name of the component.
        id: The unique ID of the component.
        config: The config of the component.
        flavor: The flavor of the component.
        type: The type of the component.
        user: The ID of the user who created the component.
        created: The creation time of the component.
        updated: The last update time of the component.
        labels: The labels of the component.
        connector_requirements: The requirements for the connector.
        connector: The ID of a connector linked to the component.
        connector_resource_id: The custom resource ID to access through
            the connector.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If a secret reference is passed as name.
    """
    # The component name is used to look the component up, so it must
    # never itself be a secret reference.
    if secret_utils.is_secret_reference(name):
        raise ValueError(
            "Passing the `name` attribute of a stack component as a "
            "secret reference is not allowed."
        )

    # Plain attribute assignments; the validated config object is kept
    # private and exposed via the `config` property on subclasses.
    self.id = id
    self.name = name
    self._config = config
    self.flavor = flavor
    self.type = type
    self.user = user
    self.created = created
    self.updated = updated
    self.labels = labels
    self.connector_requirements = connector_requirements
    self.connector = connector
    self.connector_resource_id = connector_resource_id
    # Lazily-populated service connector client cache.
    self._connector_instance: Optional[ServiceConnector] = None
Attributes
config: MLFlowModelDeployerConfig property

Returns the MLFlowModelDeployerConfig config.

Returns:

Type Description
MLFlowModelDeployerConfig

The configuration.

local_path: str property

Returns the path to the root directory.

This is where all configurations for MLflow deployment daemon processes are stored.

If the service path is not set in the config by the user, the path is set to a local default path according to the component ID.

Returns:

Type Description
str

The path to the local service root directory.

Functions
get_model_server_info(service_instance: MLFlowDeploymentService) -> Dict[str, Optional[str]] staticmethod

Return implementation specific information relevant to the user.

Parameters:

Name Type Description Default
service_instance MLFlowDeploymentService

Instance of an MLFlowDeploymentService

required

Returns:

Type Description
Dict[str, Optional[str]]

A dictionary containing the information.

Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
@staticmethod
def get_model_server_info(  # type: ignore[override]
    service_instance: "MLFlowDeploymentService",
) -> Dict[str, Optional[str]]:
    """Return implementation specific information relevant to the user.

    Args:
        service_instance: Instance of a SeldonDeploymentService

    Returns:
        A dictionary containing the information.
    """
    return {
        "PREDICTION_URL": service_instance.endpoint.prediction_url,
        "MODEL_URI": service_instance.config.model_uri,
        "MODEL_NAME": service_instance.config.model_name,
        "REGISTRY_MODEL_NAME": service_instance.config.registry_model_name,
        "REGISTRY_MODEL_VERSION": service_instance.config.registry_model_version,
        "SERVICE_PATH": service_instance.status.runtime_path,
        "DAEMON_PID": str(service_instance.status.pid),
        "HEALTH_CHECK_URL": service_instance.endpoint.monitor.get_healthcheck_uri(
            service_instance.endpoint
        ),
    }
get_service_path(id_: UUID) -> str staticmethod

Get the path where local MLflow service information is stored.

This includes the deployment service configuration, PID and log files.

Parameters:

Name Type Description Default
id_ UUID

The ID of the MLflow model deployer.

required

Returns:

Type Description
str

The service path.

Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@staticmethod
def get_service_path(id_: UUID) -> str:
    """Compute (and ensure the existence of) the local MLflow service directory.

    The directory is where the deployment service configuration, PID file
    and log files for the given deployer are stored.

    Args:
        id_: The ID of the MLflow model deployer.

    Returns:
        The service path.
    """
    path = os.path.join(
        GlobalConfiguration().local_stores_path,
        str(id_),
    )
    # Make sure the directory tree exists before handing the path out.
    create_dir_recursive_if_not_exists(path)
    return path
perform_delete_model(service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False) -> None

Method to delete all configuration of a model server.

Parameters:

Name Type Description Default
service BaseService

The service to delete.

required
timeout int

Timeout in seconds to wait for the service to stop.

DEFAULT_SERVICE_START_STOP_TIMEOUT
force bool

If True, force the service to stop.

False
Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
def perform_delete_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> None:
    """Remove all configuration and resources of a deployed model server.

    Args:
        service: The service to delete.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.
    """
    mlflow_service = cast(MLFlowDeploymentService, service)
    # Stopping and cleaning up is delegated to the shared helper.
    self._clean_up_existing_service(
        existing_service=mlflow_service,
        timeout=timeout,
        force=force,
    )
perform_deploy_model(id: UUID, config: ServiceConfig, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT) -> BaseService

Create a new MLflow deployment service or update an existing one.

This should serve the supplied model and deployment configuration.

Note: this method takes no replace argument. A new MLflow deployment server is always created to reflect the model and other configuration parameters specified in the supplied MLflow service config; reusing or updating an equivalent existing deployment is not handled by this method.

Parameters:

Name Type Description Default
id UUID

the ID of the MLflow deployment service to be created or updated.

required
config ServiceConfig

the configuration of the model to be deployed with MLflow.

required
timeout int

the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow server is provisioned, without waiting for it to fully start.

DEFAULT_SERVICE_START_STOP_TIMEOUT

Returns:

Type Description
BaseService

The ZenML MLflow deployment service object that can be used to

BaseService

interact with the MLflow model server.

Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
def perform_deploy_model(
    self,
    id: UUID,
    config: ServiceConfig,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
    """Create and start a new MLflow deployment service.

    A new local MLflow deployment server is always provisioned to serve
    the model described by the supplied service configuration. (Reusing
    or replacing an equivalent existing deployment is not handled here;
    there is no `replace` argument on this method.)

    Args:
        id: the ID of the MLflow deployment service to be created or updated.
        config: the configuration of the model to be deployed with MLflow.
        timeout: the timeout in seconds to wait for the MLflow server
            to be provisioned and successfully started or updated. If set
            to 0, the method will return immediately after the MLflow
            server is provisioned, without waiting for it to fully start.

    Returns:
        The ZenML MLflow deployment service object that can be used to
        interact with the MLflow model server.
    """
    config = cast(MLFlowDeploymentConfig, config)
    service = self._create_new_service(
        id=id, timeout=timeout, config=config
    )
    logger.info(f"Created a new MLflow deployment service: {service}")
    return service
perform_start_model(service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT) -> BaseService

Method to start a model server.

Parameters:

Name Type Description Default
service BaseService

The service to start.

required
timeout int

Timeout in seconds to wait for the service to start.

DEFAULT_SERVICE_START_STOP_TIMEOUT

Returns:

Type Description
BaseService

The service that was started.

Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def perform_start_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> BaseService:
    """Start a model server and hand it back to the caller.

    Args:
        service: The service to start.
        timeout: Timeout in seconds to wait for the service to start.

    Returns:
        The service that was started.
    """
    # The service object knows how to start itself; we only forward the
    # timeout and return the same instance for chaining.
    service.start(timeout=timeout)
    return service
perform_stop_model(service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False) -> BaseService

Method to stop a model server.

Parameters:

Name Type Description Default
service BaseService

The service to stop.

required
timeout int

Timeout in seconds to wait for the service to stop.

DEFAULT_SERVICE_START_STOP_TIMEOUT
force bool

If True, force the service to stop.

False

Returns:

Type Description
BaseService

The service that was stopped.

Source code in src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
def perform_stop_model(
    self,
    service: BaseService,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
    force: bool = False,
) -> BaseService:
    """Stop a model server and hand it back to the caller.

    Args:
        service: The service to stop.
        timeout: Timeout in seconds to wait for the service to stop.
        force: If True, force the service to stop.

    Returns:
        The service that was stopped.
    """
    # Delegate to the service's own shutdown logic; return the same
    # instance so callers can keep a reference to the stopped service.
    service.stop(timeout=timeout, force=force)
    return service
Functions

model_registries

Initialization of the MLflow model registry.

Classes
MLFlowModelRegistry(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: BaseModelRegistry

Register models using MLflow.

Source code in src/zenml/stack/stack_component.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def __init__(
    self,
    name: str,
    id: UUID,
    config: StackComponentConfig,
    flavor: str,
    type: StackComponentType,
    user: Optional[UUID],
    created: datetime,
    updated: datetime,
    labels: Optional[Dict[str, Any]] = None,
    connector_requirements: Optional[ServiceConnectorRequirements] = None,
    connector: Optional[UUID] = None,
    connector_resource_id: Optional[str] = None,
    *args: Any,
    **kwargs: Any,
):
    """Initialize a StackComponent instance.

    Args:
        name: The name of the component.
        id: The unique ID of the component.
        config: The config of the component.
        flavor: The flavor of the component.
        type: The type of the component.
        user: The ID of the user who created the component.
        created: The creation time of the component.
        updated: The last update time of the component.
        labels: The labels of the component.
        connector_requirements: The requirements for the connector.
        connector: The ID of a connector linked to the component.
        connector_resource_id: The custom resource ID to access through
            the connector.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If a secret reference is passed as name.
    """
    # The component name must be a literal string, never a secret
    # reference, so reject it up front before storing any state.
    if secret_utils.is_secret_reference(name):
        raise ValueError(
            "Passing the `name` attribute of a stack component as a "
            "secret reference is not allowed."
        )

    self.name = name
    self.id = id
    self.flavor = flavor
    self.type = type
    self._config = config
    self.user = user
    self.created = created
    self.updated = updated
    self.labels = labels
    self.connector = connector
    self.connector_requirements = connector_requirements
    self.connector_resource_id = connector_resource_id
    # Lazily-populated cache for the instantiated service connector.
    self._connector_instance: Optional[ServiceConnector] = None
Attributes
config: MLFlowModelRegistryConfig property

Returns the MLFlowModelRegistryConfig config.

Returns:

Type Description
MLFlowModelRegistryConfig

The configuration.

mlflow_client: MlflowClient property

Get the MLflow client.

Returns:

Type Description
MlflowClient

The MLFlowClient.

validator: Optional[StackValidator] property

Validates that the stack contains an mlflow experiment tracker.

Returns:

Type Description
Optional[StackValidator]

A StackValidator instance.

Functions
configure_mlflow() -> None

Configures the MLflow Client with the experiment tracker config.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
111
112
113
114
115
def configure_mlflow(self) -> None:
    """Configures the MLflow Client with the experiment tracker config."""
    tracker = Client().active_stack.experiment_tracker
    # The registry only works alongside the MLflow experiment tracker;
    # its configuration carries the MLflow connection settings.
    assert isinstance(tracker, MLFlowExperimentTracker)
    tracker.configure_mlflow()
delete_model(name: str) -> None

Delete a model from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required

Raises:

Type Description
KeyError

If the model does not exist.

RuntimeError

If mlflow fails to delete the model.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
def delete_model(
    self,
    name: str,
) -> None:
    """Delete a model from the MLflow model registry.

    Args:
        name: The name of the model.

    Raises:
        KeyError: If the model does not exist (propagated from
            `get_model`).
        RuntimeError: If mlflow fails to delete the model.
    """
    # Check if model exists; `get_model` raises KeyError otherwise.
    self.get_model(name=name)
    # Delete the registered model.
    try:
        self.mlflow_client.delete_registered_model(
            name=name,
        )
    except MlflowException as e:
        raise RuntimeError(
            f"Failed to delete model with name {name} from MLflow model "
            f"registry: {str(e)}",
        )
delete_model_version(name: str, version: str) -> None

Delete a model version from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required

Raises:

Type Description
RuntimeError

If mlflow fails to delete the model version.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
def delete_model_version(
    self,
    name: str,
    version: str,
) -> None:
    """Delete a model version from the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.

    Raises:
        RuntimeError: If mlflow fails to delete the model version.
    """
    # Raises KeyError if the version does not exist.
    self.get_model_version(name=name, version=version)
    try:
        self.mlflow_client.delete_model_version(
            name=name,
            version=version,
        )
    except MlflowException as e:
        # Fixed message: the previous two fragments joined as
        # "...'name'.From the MLflow..." (missing space, stray capital).
        raise RuntimeError(
            f"Failed to delete model version '{version}' of model '{name}' "
            f"from the MLflow model registry: {str(e)}",
        )
get_model(name: str) -> RegisteredModel

Get a model from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required

Returns:

Type Description
RegisteredModel

The model.

Raises:

Type Description
KeyError

If mlflow fails to get the model.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
def get_model(self, name: str) -> RegisteredModel:
    """Get a model from the MLflow model registry.

    Args:
        name: The name of the model.

    Returns:
        The model.

    Raises:
        KeyError: If mlflow fails to get the model.
    """
    # Look the model up in the MLflow backend; a missing model surfaces
    # as an MlflowException which we translate into a KeyError.
    try:
        registered_model = self.mlflow_client.get_registered_model(
            name=name,
        )
    except MlflowException as e:
        raise KeyError(
            f"Failed to get model with name {name} from the MLflow model "
            f"registry: {str(e)}",
        )
    # Convert the MLflow entity into the ZenML representation.
    return RegisteredModel(
        name=registered_model.name,
        description=registered_model.description,
        metadata=registered_model.tags,
    )
get_model_uri_artifact_store(model_version: RegistryModelVersion) -> str

Get the model URI artifact store.

Parameters:

Name Type Description Default
model_version RegistryModelVersion

The model version.

required

Returns:

Type Description
str

The model URI artifact store.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
def get_model_uri_artifact_store(
    self,
    model_version: RegistryModelVersion,
) -> str:
    """Get the model URI artifact store.

    Args:
        model_version: The model version.

    Returns:
        The model URI artifact store.
    """
    # Strip the URI scheme (everything up to the last ':') and re-root
    # the path under the active stack's artifact store.
    store_root = f"{Client().active_stack.artifact_store.path}/mlflow"
    relative_uri = model_version.model_source_uri.rsplit(":")[-1]
    return store_root + relative_uri
get_model_version(name: str, version: str) -> RegistryModelVersion

Get a model version from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required

Raises:

Type Description
KeyError

If the model version does not exist.

Returns:

Type Description
RegistryModelVersion

The model version.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
def get_model_version(
    self,
    name: str,
    version: str,
) -> RegistryModelVersion:
    """Get a model version from the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.

    Raises:
        KeyError: If the model version does not exist.

    Returns:
        The model version.
    """
    # Fetch from MLflow; a missing version surfaces as an
    # MlflowException, which we re-raise as KeyError for callers.
    try:
        mlflow_model_version = self.mlflow_client.get_model_version(
            name=name,
            version=version,
        )
    except MlflowException as e:
        raise KeyError(
            f"Failed to get model version '{name}:{version}' from the "
            f"MLflow model registry: {str(e)}"
        )
    # Translate the MLflow entity into ZenML's model version type.
    return self._cast_mlflow_version_to_model_version(
        mlflow_model_version=mlflow_model_version,
    )
list_model_versions(name: Optional[str] = None, model_source_uri: Optional[str] = None, metadata: Optional[ModelRegistryModelMetadata] = None, stage: Optional[ModelVersionStage] = None, count: Optional[int] = None, created_after: Optional[datetime] = None, created_before: Optional[datetime] = None, order_by_date: Optional[str] = None, **kwargs: Any) -> List[RegistryModelVersion]

List model versions from the MLflow model registry.

Parameters:

Name Type Description Default
name Optional[str]

The name of the model.

None
model_source_uri Optional[str]

The model source URI.

None
metadata Optional[ModelRegistryModelMetadata]

The metadata of the model version.

None
stage Optional[ModelVersionStage]

The stage of the model version.

None
count Optional[int]

The maximum number of model versions to return.

None
created_after Optional[datetime]

The minimum creation time of the model versions.

None
created_before Optional[datetime]

The maximum creation time of the model versions.

None
order_by_date Optional[str]

The order of the model versions by creation time, either ascending or descending.

None
kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
List[RegistryModelVersion]

The model versions.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
def list_model_versions(
    self,
    name: Optional[str] = None,
    model_source_uri: Optional[str] = None,
    metadata: Optional[ModelRegistryModelMetadata] = None,
    stage: Optional[ModelVersionStage] = None,
    count: Optional[int] = None,
    created_after: Optional[datetime] = None,
    created_before: Optional[datetime] = None,
    order_by_date: Optional[str] = None,
    **kwargs: Any,
) -> List[RegistryModelVersion]:
    """List model versions from the MLflow model registry.

    Args:
        name: The name of the model.
        model_source_uri: The model source URI.
        metadata: The metadata of the model version.
        stage: The stage of the model version.
        count: The maximum number of model versions to return.
        created_after: The minimum creation time of the model versions.
        created_before: The maximum creation time of the model versions.
        order_by_date: The order of the model versions by creation time,
            either ascending or descending.
        kwargs: Additional keyword arguments.

    Returns:
        The model versions.
    """
    # Assemble the MLflow filter expression from the requested criteria.
    conditions = []
    if name:
        conditions.append(f"name='{name}'")
    if model_source_uri:
        conditions.append(f"source='{model_source_uri}'")
    if kwargs.get("mlflow_run_id"):
        conditions.append(f"run_id='{kwargs['mlflow_run_id']}'")
    if metadata:
        for tag, value in metadata.model_dump().items():
            if value:
                conditions.append(f"tags.{tag}='{value}'")
    filter_string = " AND ".join(conditions)

    # Only "asc"/"desc" are honored; anything else leaves MLflow's
    # default ordering in place.
    order_by = []
    if order_by_date == "asc":
        order_by = ["creation_timestamp ASC"]
    elif order_by_date == "desc":
        order_by = ["creation_timestamp DESC"]

    mlflow_model_versions = self.mlflow_client.search_model_versions(
        filter_string=filter_string,
        order_by=order_by,
    )

    # Convert MLflow versions into ZenML model versions, applying the
    # stage / time-window filters that MLflow cannot express for us.
    model_versions: List[RegistryModelVersion] = []
    for mv in mlflow_model_versions:
        if stage and not ModelVersionStage(mv.current_stage) == stage:
            continue
        if created_after and not (
            mv.creation_timestamp >= created_after.timestamp()
        ):
            continue
        if created_before and not (
            mv.creation_timestamp <= created_before.timestamp()
        ):
            continue
        try:
            model_versions.append(
                self._cast_mlflow_version_to_model_version(
                    mlflow_model_version=mv,
                )
            )
        except (AttributeError, OSError) as e:
            # Sometimes, the Model Registry in MLflow can become unusable
            # due to failed version registration or misuse. In such rare
            # cases, it's best to suppress those versions that are not usable.
            logger.warning(
                "Error encountered while loading MLflow model version "
                f"`{mv.name}:{mv.version}`: {e}"
            )
        if count and len(model_versions) == count:
            break

    return model_versions
list_models(name: Optional[str] = None, metadata: Optional[Dict[str, str]] = None) -> List[RegisteredModel]

List models in the MLflow model registry.

Parameters:

Name Type Description Default
name Optional[str]

A name to filter the models by.

None
metadata Optional[Dict[str, str]]

The metadata to filter the models by.

None

Returns:

Type Description
List[RegisteredModel]

A list of models (RegisteredModel)

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def list_models(
    self,
    name: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> List[RegisteredModel]:
    """List models in the MLflow model registry.

    Args:
        name: A name to filter the models by.
        metadata: The metadata to filter the models by.

    Returns:
        A list of models (RegisteredModel)
    """
    # Build the MLflow filter expression from the requested criteria.
    conditions = []
    if name:
        conditions.append(f"name='{name}'")
    if metadata:
        conditions.extend(
            f"tags.{tag}='{value}'" for tag, value in metadata.items()
        )
    filter_string = " AND ".join(conditions)

    # Query MLflow for matching registered models.
    registered_models = self.mlflow_client.search_registered_models(
        filter_string=filter_string,
        max_results=100,
    )
    # Convert each MLflow entity into the ZenML representation.
    return [
        RegisteredModel(
            name=registered_model.name,
            description=registered_model.description,
            metadata=registered_model.tags,
        )
        for registered_model in registered_models
    ]
load_model_version(name: str, version: str, **kwargs: Any) -> Any

Load a model version from the MLflow model registry.

This method loads the model version from the MLflow model registry and returns the model. The model is loaded using the mlflow.pyfunc module which takes care of loading the model from the model source URI for the right framework.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required
kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Any

The model version.

Raises:

Type Description
KeyError

If the model version does not exist.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
def load_model_version(
    self,
    name: str,
    version: str,
    **kwargs: Any,
) -> Any:
    """Load a model version from the MLflow model registry.

    This method loads the model version from the MLflow model registry
    and returns the model. The model is loaded using the `mlflow.pyfunc`
    module which takes care of loading the model from the model source
    URI for the right framework.

    Args:
        name: The name of the model.
        version: The version of the model.
        kwargs: Additional keyword arguments.

    Returns:
        The model version.

    Raises:
        KeyError: If the model version does not exist.
    """
    # Validate existence first so callers see a uniform error message.
    try:
        self.get_model_version(name=name, version=version)
    except KeyError:
        raise KeyError(
            f"Failed to load model version '{name}:{version}' from the "
            f"MLflow model registry: Model version does not exist."
        )
    # Fetch the raw MLflow version to obtain its source URI, then let
    # the MLflow loader pick the right framework flavor.
    mlflow_model_version = self.mlflow_client.get_model_version(
        name=name,
        version=version,
    )
    return load_model(
        model_uri=mlflow_model_version.source,
        **kwargs,
    )
register_model(name: str, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None) -> RegisteredModel

Register a model to the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
description Optional[str]

The description of the model.

None
metadata Optional[Dict[str, str]]

The metadata of the model.

None

Raises:

Type Description
RuntimeError

If the model already exists.

Returns:

Type Description
RegisteredModel

The registered model.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def register_model(
    self,
    name: str,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> RegisteredModel:
    """Register a model to the MLflow model registry.

    Args:
        name: The name of the model.
        description: The description of the model.
        metadata: The metadata of the model.

    Raises:
        RuntimeError: If the model already exists, or if MLflow fails to
            register the model.

    Returns:
        The registered model.
    """
    # Check if model already exists. The previous implementation raised
    # KeyError inside the same `try` whose `except KeyError: pass`
    # caught it, so the "already exists" error was always silently
    # swallowed; try/except/else keeps the duplicate check effective.
    try:
        self.get_model(name)
    except KeyError:
        # Model does not exist yet - safe to register it below.
        pass
    else:
        raise RuntimeError(
            f"Model with name {name} already exists in the MLflow model "
            f"registry. Please use a different name.",
        )
    # Register model.
    try:
        registered_model = self.mlflow_client.create_registered_model(
            name=name,
            description=description,
            tags=metadata,
        )
    except MlflowException as e:
        raise RuntimeError(
            f"Failed to register model with name {name} to the MLflow "
            f"model registry: {str(e)}",
        )

    # Return the registered model.
    return RegisteredModel(
        name=registered_model.name,
        description=registered_model.description,
        metadata=registered_model.tags,
    )
register_model_version(name: str, version: Optional[str] = None, model_source_uri: Optional[str] = None, description: Optional[str] = None, metadata: Optional[ModelRegistryModelMetadata] = None, **kwargs: Any) -> RegistryModelVersion

Register a model version to the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
model_source_uri Optional[str]

The source URI of the model.

None
version Optional[str]

The version of the model.

None
description Optional[str]

The description of the model version.

None
metadata Optional[ModelRegistryModelMetadata]

The registry metadata of the model version.

None
**kwargs Any

Additional keyword arguments.

{}

Raises:

Type Description
RuntimeError

If the registered model does not exist.

ValueError

If no model source URI was provided.

Returns:

Type Description
RegistryModelVersion

The registered model version.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
def register_model_version(
    self,
    name: str,
    version: Optional[str] = None,
    model_source_uri: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[ModelRegistryModelMetadata] = None,
    **kwargs: Any,
) -> RegistryModelVersion:
    """Register a model version to the MLflow model registry.

    If no registered model with the given name exists yet, one is
    created on the fly before the version is registered.

    Args:
        name: The name of the model.
        model_source_uri: The source URI of the model.
        version: The version of the model. MLflow assigns version
            numbers itself, so this value is only logged, not used.
        description: The description of the model version.
        metadata: The registry metadata of the model version.
        **kwargs: Additional keyword arguments.

    Raises:
        RuntimeError: If the registered model does not exist.
        ValueError: If no model source URI was provided.

    Returns:
        The registered model version.
    """
    if not model_source_uri:
        raise ValueError(
            "Unable to register model version without model source URI."
        )

    # Check if the model exists, if not create it.
    try:
        self.get_model(name=name)
    except KeyError:
        logger.info(
            f"No registered model with name {name} found. Creating a new "
            "registered model."
        )
        self.register_model(
            name=name,
        )
    try:
        # Inform the user that the version argument is ignored.
        if version:
            logger.info(
                f"MLflow model registry does not take a version as an "
                f"argument. Registering a new version for the model "
                f"'{name}'; a version will be assigned automatically."
            )
        metadata_dict = metadata.model_dump() if metadata else {}
        # The run ID and link are passed both as dedicated fields and as
        # part of the version tags.
        run_id = metadata_dict.get("mlflow_run_id", None)
        run_link = metadata_dict.get("mlflow_run_link", None)
        # Register the model version.
        registered_model_version = self.mlflow_client.create_model_version(
            name=name,
            source=model_source_uri,
            run_id=run_id,
            run_link=run_link,
            description=description,
            tags=metadata_dict,
        )
    except MlflowException as e:
        # Chain the MLflow error so the original traceback is preserved.
        raise RuntimeError(
            f"Failed to register model version with name '{name}' and "
            f"version '{version}' to the MLflow model registry."
            f"Error: {e}"
        ) from e
    # Return the registered model version.
    return self._cast_mlflow_version_to_model_version(
        registered_model_version
    )
update_model(name: str, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, remove_metadata: Optional[List[str]] = None) -> RegisteredModel

Update a model in the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
description Optional[str]

The description of the model.

None
metadata Optional[Dict[str, str]]

The metadata of the model.

None
remove_metadata Optional[List[str]]

The metadata to remove from the model.

None

Raises:

Type Description
RuntimeError

If mlflow fails to update the model.

Returns:

Type Description
RegisteredModel

The updated model.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
def update_model(
    self,
    name: str,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
    remove_metadata: Optional[List[str]] = None,
) -> RegisteredModel:
    """Update a model in the MLflow model registry.

    Args:
        name: The name of the model.
        description: The description of the model.
        metadata: The metadata of the model.
        remove_metadata: The metadata to remove from the model.

    Raises:
        RuntimeError: If mlflow fails to update the model.

    Returns:
        The updated model.
    """
    # Check if model exists; raises KeyError if it does not.
    self.get_model(name=name)
    # Update the registered model description.
    if description:
        try:
            self.mlflow_client.update_registered_model(
                name=name,
                description=description,
            )
        except MlflowException as e:
            # Chain the MLflow error to keep the original traceback.
            raise RuntimeError(
                f"Failed to update description for the model {name} in MLflow "
                f"model registry: {str(e)}",
            ) from e
    # Update the registered model tags.
    if metadata:
        try:
            for tag, value in metadata.items():
                self.mlflow_client.set_registered_model_tag(
                    name=name,
                    key=tag,
                    value=value,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update tags for the model {name} in MLflow model "
                f"registry: {str(e)}",
            ) from e
    # Remove tags from the registered model.
    if remove_metadata:
        try:
            for tag in remove_metadata:
                self.mlflow_client.delete_registered_model_tag(
                    name=name,
                    key=tag,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to remove tags for the model {name} in MLflow model "
                f"registry: {str(e)}",
            ) from e
    # Return the updated registered model.
    return self.get_model(name)
update_model_version(name: str, version: str, description: Optional[str] = None, metadata: Optional[ModelRegistryModelMetadata] = None, remove_metadata: Optional[List[str]] = None, stage: Optional[ModelVersionStage] = None) -> RegistryModelVersion

Update a model version in the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required
description Optional[str]

The description of the model version.

None
metadata Optional[ModelRegistryModelMetadata]

The metadata of the model version.

None
remove_metadata Optional[List[str]]

The metadata to remove from the model version.

None
stage Optional[ModelVersionStage]

The stage of the model version.

None

Raises:

Type Description
RuntimeError

If mlflow fails to update the model version.

Returns:

Type Description
RegistryModelVersion

The updated model version.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
def update_model_version(
    self,
    name: str,
    version: str,
    description: Optional[str] = None,
    metadata: Optional[ModelRegistryModelMetadata] = None,
    remove_metadata: Optional[List[str]] = None,
    stage: Optional[ModelVersionStage] = None,
) -> RegistryModelVersion:
    """Update a model version in the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.
        description: The description of the model version.
        metadata: The metadata of the model version.
        remove_metadata: The metadata to remove from the model version.
        stage: The stage of the model version.

    Raises:
        RuntimeError: If mlflow fails to update the model version.

    Returns:
        The updated model version.
    """
    # Raises KeyError if the model version does not exist.
    self.get_model_version(name=name, version=version)
    # Update the model description.
    if description:
        try:
            self.mlflow_client.update_model_version(
                name=name,
                version=version,
                description=description,
            )
        except MlflowException as e:
            # Chain the MLflow error to keep the original traceback.
            raise RuntimeError(
                f"Failed to update the description of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            ) from e
    # Update the model tags.
    if metadata:
        try:
            for key, value in metadata.model_dump().items():
                self.mlflow_client.set_model_version_tag(
                    name=name,
                    version=version,
                    key=key,
                    value=value,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update the tags of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            ) from e
    # Remove the model tags.
    if remove_metadata:
        try:
            for key in remove_metadata:
                self.mlflow_client.delete_model_version_tag(
                    name=name,
                    version=version,
                    key=key,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to remove the tags of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            ) from e
    # Update the model stage.
    if stage:
        try:
            self.mlflow_client.transition_model_version_stage(
                name=name,
                version=version,
                stage=stage.value,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update the current stage of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            ) from e
    return self.get_model_version(name, version)
Modules
mlflow_model_registry

Implementation of the MLflow model registry for ZenML.

Classes
MLFlowModelRegistry(name: str, id: UUID, config: StackComponentConfig, flavor: str, type: StackComponentType, user: Optional[UUID], created: datetime, updated: datetime, labels: Optional[Dict[str, Any]] = None, connector_requirements: Optional[ServiceConnectorRequirements] = None, connector: Optional[UUID] = None, connector_resource_id: Optional[str] = None, *args: Any, **kwargs: Any)

Bases: BaseModelRegistry

Register models using MLflow.

Source code in src/zenml/stack/stack_component.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def __init__(
    self,
    name: str,
    id: UUID,
    config: StackComponentConfig,
    flavor: str,
    type: StackComponentType,
    user: Optional[UUID],
    created: datetime,
    updated: datetime,
    labels: Optional[Dict[str, Any]] = None,
    connector_requirements: Optional[ServiceConnectorRequirements] = None,
    connector: Optional[UUID] = None,
    connector_resource_id: Optional[str] = None,
    *args: Any,
    **kwargs: Any,
):
    """Initializes a StackComponent.

    Args:
        name: The name of the component.
        id: The unique ID of the component.
        config: The config of the component.
        flavor: The flavor of the component.
        type: The type of the component.
        user: The ID of the user who created the component.
        created: The creation time of the component.
        updated: The last update time of the component.
        labels: The labels of the component.
        connector_requirements: The requirements for the connector.
        connector: The ID of a connector linked to the component.
        connector_resource_id: The custom resource ID to access through
            the connector.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Raises:
        ValueError: If a secret reference is passed as name.
    """
    # A secret reference cannot serve as a component name.
    if secret_utils.is_secret_reference(name):
        raise ValueError(
            "Passing the `name` attribute of a stack component as a "
            "secret reference is not allowed."
        )

    # Identity and configuration.
    self.name = name
    self.id = id
    self._config = config
    self.flavor = flavor
    self.type = type
    # Ownership and audit metadata.
    self.user = user
    self.created = created
    self.updated = updated
    self.labels = labels
    # Service connector wiring; the instance itself is resolved lazily.
    self.connector_requirements = connector_requirements
    self.connector = connector
    self.connector_resource_id = connector_resource_id
    self._connector_instance: Optional[ServiceConnector] = None
Attributes
config: MLFlowModelRegistryConfig property

Returns the MLFlowModelRegistryConfig config.

Returns:

Type Description
MLFlowModelRegistryConfig

The configuration.

mlflow_client: MlflowClient property

Get the MLflow client.

Returns:

Type Description
MlflowClient

The MLFlowClient.

validator: Optional[StackValidator] property

Validates that the stack contains an mlflow experiment tracker.

Returns:

Type Description
Optional[StackValidator]

A StackValidator instance.

Functions
configure_mlflow() -> None

Configures the MLflow Client with the experiment tracker config.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
111
112
113
114
115
def configure_mlflow(self) -> None:
    """Configures the MLflow Client with the experiment tracker config."""
    tracker = Client().active_stack.experiment_tracker
    # The assert narrows the type: the stack is expected to contain an
    # MLflow experiment tracker at this point.
    assert isinstance(tracker, MLFlowExperimentTracker)
    tracker.configure_mlflow()
delete_model(name: str) -> None

Delete a model from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required

Raises:

Type Description
RuntimeError

If the model does not exist.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
def delete_model(
    self,
    name: str,
) -> None:
    """Delete a model from the MLflow model registry.

    Args:
        name: The name of the model.

    Raises:
        RuntimeError: If the model does not exist.
    """
    # Check if model exists; raises KeyError if it does not.
    self.get_model(name=name)
    # Delete the registered model.
    try:
        self.mlflow_client.delete_registered_model(
            name=name,
        )
    except MlflowException as e:
        # Chain the MLflow error to keep the original traceback.
        raise RuntimeError(
            f"Failed to delete model with name {name} from MLflow model "
            f"registry: {str(e)}",
        ) from e
delete_model_version(name: str, version: str) -> None

Delete a model version from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required

Raises:

Type Description
RuntimeError

If mlflow fails to delete the model version.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
def delete_model_version(
    self,
    name: str,
    version: str,
) -> None:
    """Delete a model version from the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.

    Raises:
        RuntimeError: If mlflow fails to delete the model version.
    """
    # Raises KeyError if the model version does not exist.
    self.get_model_version(name=name, version=version)
    try:
        self.mlflow_client.delete_model_version(
            name=name,
            version=version,
        )
    except MlflowException as e:
        # Chain the MLflow error and fix the previously fragmented
        # error message ("...'." "From the MLflow...").
        raise RuntimeError(
            f"Failed to delete model version '{version}' of model '{name}' "
            f"from the MLflow model registry: {str(e)}"
        ) from e
get_model(name: str) -> RegisteredModel

Get a model from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required

Returns:

Type Description
RegisteredModel

The model.

Raises:

Type Description
KeyError

If mlflow fails to get the model.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
def get_model(self, name: str) -> RegisteredModel:
    """Get a model from the MLflow model registry.

    Args:
        name: The name of the model.

    Returns:
        The model.

    Raises:
        KeyError: If mlflow fails to get the model.
    """
    # Get the registered model.
    try:
        registered_model = self.mlflow_client.get_registered_model(
            name=name,
        )
    except MlflowException as e:
        # Chain the MLflow error to keep the original traceback.
        raise KeyError(
            f"Failed to get model with name {name} from the MLflow model "
            f"registry: {str(e)}",
        ) from e
    # Return the registered model.
    return RegisteredModel(
        name=registered_model.name,
        description=registered_model.description,
        metadata=registered_model.tags,
    )
get_model_uri_artifact_store(model_version: RegistryModelVersion) -> str

Get the model URI artifact store.

Parameters:

Name Type Description Default
model_version RegistryModelVersion

The model version.

required

Returns:

Type Description
str

The model URI artifact store.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
def get_model_uri_artifact_store(
    self,
    model_version: RegistryModelVersion,
) -> str:
    """Get the model URI artifact store.

    Args:
        model_version: The model version.

    Returns:
        The model URI artifact store.
    """
    # MLflow artifacts live under the "mlflow" prefix of the active
    # artifact store.
    mlflow_root = f"{Client().active_stack.artifact_store.path}/mlflow"
    # Keep everything after the last ':' of the source URI (drops the
    # URI scheme portion).
    relative_path = model_version.model_source_uri.rsplit(":")[-1]
    return mlflow_root + relative_path
get_model_version(name: str, version: str) -> RegistryModelVersion

Get a model version from the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required

Raises:

Type Description
KeyError

If the model version does not exist.

Returns:

Type Description
RegistryModelVersion

The model version.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
def get_model_version(
    self,
    name: str,
    version: str,
) -> RegistryModelVersion:
    """Get a model version from the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.

    Raises:
        KeyError: If the model version does not exist.

    Returns:
        The model version.
    """
    # Get the model version from the MLflow model registry.
    try:
        mlflow_model_version = self.mlflow_client.get_model_version(
            name=name,
            version=version,
        )
    except MlflowException as e:
        # Chain the MLflow error to keep the original traceback.
        raise KeyError(
            f"Failed to get model version '{name}:{version}' from the "
            f"MLflow model registry: {str(e)}"
        ) from e
    # Return the model version.
    return self._cast_mlflow_version_to_model_version(
        mlflow_model_version=mlflow_model_version,
    )
list_model_versions(name: Optional[str] = None, model_source_uri: Optional[str] = None, metadata: Optional[ModelRegistryModelMetadata] = None, stage: Optional[ModelVersionStage] = None, count: Optional[int] = None, created_after: Optional[datetime] = None, created_before: Optional[datetime] = None, order_by_date: Optional[str] = None, **kwargs: Any) -> List[RegistryModelVersion]

List model versions from the MLflow model registry.

Parameters:

Name Type Description Default
name Optional[str]

The name of the model.

None
model_source_uri Optional[str]

The model source URI.

None
metadata Optional[ModelRegistryModelMetadata]

The metadata of the model version.

None
stage Optional[ModelVersionStage]

The stage of the model version.

None
count Optional[int]

The maximum number of model versions to return.

None
created_after Optional[datetime]

The minimum creation time of the model versions.

None
created_before Optional[datetime]

The maximum creation time of the model versions.

None
order_by_date Optional[str]

The order of the model versions by creation time, either ascending or descending.

None
kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
List[RegistryModelVersion]

The model versions.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
def list_model_versions(
    self,
    name: Optional[str] = None,
    model_source_uri: Optional[str] = None,
    metadata: Optional[ModelRegistryModelMetadata] = None,
    stage: Optional[ModelVersionStage] = None,
    count: Optional[int] = None,
    created_after: Optional[datetime] = None,
    created_before: Optional[datetime] = None,
    order_by_date: Optional[str] = None,
    **kwargs: Any,
) -> List[RegistryModelVersion]:
    """List model versions from the MLflow model registry.

    Args:
        name: The name of the model.
        model_source_uri: The model source URI.
        metadata: The metadata of the model version.
        stage: The stage of the model version.
        count: The maximum number of model versions to return.
        created_after: The minimum creation time of the model versions.
        created_before: The maximum creation time of the model versions.
        order_by_date: The order of the model versions by creation time,
            either ascending or descending.
        kwargs: Additional keyword arguments.

    Returns:
        The model versions.
    """
    # Assemble the MLflow search filter from the given criteria.
    conditions = []
    if name:
        conditions.append(f"name='{name}'")
    if model_source_uri:
        conditions.append(f"source='{model_source_uri}'")
    if kwargs.get("mlflow_run_id"):
        conditions.append(f"run_id='{kwargs['mlflow_run_id']}'")
    if metadata:
        for tag, value in metadata.model_dump().items():
            if value:
                conditions.append(f"tags.{tag}='{value}'")
    filter_string = " AND ".join(conditions)
    # Translate the requested ordering into MLflow order_by clauses.
    order_by = []
    if order_by_date == "asc":
        order_by = ["creation_timestamp ASC"]
    elif order_by_date == "desc":
        order_by = ["creation_timestamp DESC"]
    mlflow_model_versions = self.mlflow_client.search_model_versions(
        filter_string=filter_string,
        order_by=order_by,
    )
    # Post-filter the results and cast them to ZenML model versions.
    model_versions: List[RegistryModelVersion] = []
    for mlflow_version in mlflow_model_versions:
        if (
            stage
            and ModelVersionStage(mlflow_version.current_stage) != stage
        ):
            continue
        if (
            created_after
            and mlflow_version.creation_timestamp
            < created_after.timestamp()
        ):
            continue
        if (
            created_before
            and mlflow_version.creation_timestamp
            > created_before.timestamp()
        ):
            continue
        try:
            model_versions.append(
                self._cast_mlflow_version_to_model_version(
                    mlflow_model_version=mlflow_version,
                )
            )
        except (AttributeError, OSError) as e:
            # Sometimes, the Model Registry in MLflow can become unusable
            # due to failed version registration or misuse. In such rare
            # cases, it's best to suppress those versions that are not usable.
            logger.warning(
                "Error encountered while loading MLflow model version "
                f"`{mlflow_version.name}:{mlflow_version.version}`: {e}"
            )
        if count and len(model_versions) == count:
            return model_versions

    return model_versions
list_models(name: Optional[str] = None, metadata: Optional[Dict[str, str]] = None) -> List[RegisteredModel]

List models in the MLflow model registry.

Parameters:

Name Type Description Default
name Optional[str]

A name to filter the models by.

None
metadata Optional[Dict[str, str]]

The metadata to filter the models by.

None

Returns:

Type Description
List[RegisteredModel]

A list of models (RegisteredModel)

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
def list_models(
    self,
    name: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> List[RegisteredModel]:
    """List models in the MLflow model registry.

    Args:
        name: A name to filter the models by.
        metadata: The metadata to filter the models by.

    Returns:
        A list of models (RegisteredModel)
    """
    # Assemble the MLflow search filter from the given criteria.
    conditions = []
    if name:
        conditions.append(f"name='{name}'")
    if metadata:
        conditions.extend(
            f"tags.{tag}='{value}'" for tag, value in metadata.items()
        )
    filter_string = " AND ".join(conditions)

    # Get the registered models.
    matches = self.mlflow_client.search_registered_models(
        filter_string=filter_string,
        max_results=100,
    )
    # Cast each match to the ZenML RegisteredModel type.
    return [
        RegisteredModel(
            name=match.name,
            description=match.description,
            metadata=match.tags,
        )
        for match in matches
    ]
load_model_version(name: str, version: str, **kwargs: Any) -> Any

Load a model version from the MLflow model registry.

This method loads the model version from the MLflow model registry and returns the model. The model is loaded using the mlflow.pyfunc module which takes care of loading the model from the model source URI for the right framework.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required
kwargs Any

Additional keyword arguments.

{}

Returns:

Type Description
Any

The model version.

Raises:

Type Description
KeyError

If the model version does not exist.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
def load_model_version(
    self,
    name: str,
    version: str,
    **kwargs: Any,
) -> Any:
    """Load a model version from the MLflow model registry.

    This method loads the model version from the MLflow model registry
    and returns the model. The model is loaded using the `mlflow.pyfunc`
    module which takes care of loading the model from the model source
    URI for the right framework.

    Args:
        name: The name of the model.
        version: The version of the model.
        kwargs: Additional keyword arguments.

    Returns:
        The model version.

    Raises:
        KeyError: If the model version does not exist.
    """
    # Validate that the version exists before attempting to load it.
    try:
        self.get_model_version(name=name, version=version)
    except KeyError as e:
        # Chain the lookup error so the original cause is preserved.
        raise KeyError(
            f"Failed to load model version '{name}:{version}' from the "
            f"MLflow model registry: Model version does not exist."
        ) from e
    # Fetch the raw MLflow version to obtain its source URI.
    mlflow_model_version = self.mlflow_client.get_model_version(
        name=name,
        version=version,
    )
    return load_model(
        model_uri=mlflow_model_version.source,
        **kwargs,
    )
register_model(name: str, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None) -> RegisteredModel

Register a model to the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
description Optional[str]

The description of the model.

None
metadata Optional[Dict[str, str]]

The metadata of the model.

None

Raises:

Type Description
RuntimeError

If the model already exists.

Returns:

Type Description
RegisteredModel

The registered model.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def register_model(
    self,
    name: str,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
) -> RegisteredModel:
    """Register a model to the MLflow model registry.

    Args:
        name: The name of the model.
        description: The description of the model.
        metadata: The metadata of the model.

    Raises:
        KeyError: If a model with the same name already exists.
        RuntimeError: If mlflow fails to register the model.

    Returns:
        The registered model.
    """
    # Check if model already exists. The duplicate-name error must be
    # raised in the `else` branch: raising it inside the `try` block
    # would be swallowed immediately by the `except KeyError` that is
    # meant only for the "model not found" case.
    try:
        self.get_model(name)
    except KeyError:
        pass
    else:
        raise KeyError(
            f"Model with name {name} already exists in the MLflow model "
            f"registry. Please use a different name.",
        )
    # Register model.
    try:
        registered_model = self.mlflow_client.create_registered_model(
            name=name,
            description=description,
            tags=metadata,
        )
    except MlflowException as e:
        # Chain the MLflow error to keep the original traceback.
        raise RuntimeError(
            f"Failed to register model with name {name} to the MLflow "
            f"model registry: {str(e)}",
        ) from e

    # Return the registered model.
    return RegisteredModel(
        name=registered_model.name,
        description=registered_model.description,
        metadata=registered_model.tags,
    )
register_model_version(name: str, version: Optional[str] = None, model_source_uri: Optional[str] = None, description: Optional[str] = None, metadata: Optional[ModelRegistryModelMetadata] = None, **kwargs: Any) -> RegistryModelVersion

Register a model version to the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
model_source_uri Optional[str]

The source URI of the model.

None
version Optional[str]

The version of the model.

None
description Optional[str]

The description of the model version.

None
metadata Optional[ModelRegistryModelMetadata]

The registry metadata of the model version.

None
**kwargs Any

Additional keyword arguments.

{}

Raises:

Type Description
RuntimeError

If the registered model does not exist.

ValueError

If no model source URI was provided.

Returns:

Type Description
RegistryModelVersion

The registered model version.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
def register_model_version(
    self,
    name: str,
    version: Optional[str] = None,
    model_source_uri: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[ModelRegistryModelMetadata] = None,
    **kwargs: Any,
) -> RegistryModelVersion:
    """Register a model version to the MLflow model registry.

    Args:
        name: The name of the model.
        model_source_uri: The source URI of the model.
        version: The version of the model.
        description: The description of the model version.
        metadata: The registry metadata of the model version.
        **kwargs: Additional keyword arguments.

    Raises:
        RuntimeError: If the registered model does not exist.
        ValueError: If no model source URI was provided.

    Returns:
        The registered model version.
    """
    if not model_source_uri:
        raise ValueError(
            "Unable to register model version without model source URI."
        )

    # Check if the model exists, if not create it.
    try:
        self.get_model(name=name)
    except KeyError:
        logger.info(
            f"No registered model with name {name} found. Creating a new "
            "registered model."
        )
        self.register_model(
            name=name,
        )
    try:
        # MLflow assigns version numbers itself; an explicitly requested
        # version can only be acknowledged, not honored.
        if version:
            logger.info(
                f"MLflow model registry does not take a version as an argument. "
                f"Registering a new version for the model `'{name}'` "
                f"a version will be assigned automatically."
            )
        metadata_dict = metadata.model_dump() if metadata else {}
        # Set the run ID and link.
        run_id = metadata_dict.get("mlflow_run_id", None)
        run_link = metadata_dict.get("mlflow_run_link", None)
        # Register the model version.
        registered_model_version = self.mlflow_client.create_model_version(
            name=name,
            source=model_source_uri,
            run_id=run_id,
            run_link=run_link,
            description=description,
            tags=metadata_dict,
        )
    except MlflowException as e:
        # Chain the original MLflow error for debugging, and keep a space
        # between the sentence and the error details (the message
        # previously rendered as "...registry.Error: ...").
        raise RuntimeError(
            f"Failed to register model version with name '{name}' and "
            f"version '{version}' to the MLflow model registry. "
            f"Error: {e}"
        ) from e
    # Return the registered model version.
    return self._cast_mlflow_version_to_model_version(
        registered_model_version
    )
update_model(name: str, description: Optional[str] = None, metadata: Optional[Dict[str, str]] = None, remove_metadata: Optional[List[str]] = None) -> RegisteredModel

Update a model in the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
description Optional[str]

The description of the model.

None
metadata Optional[Dict[str, str]]

The metadata of the model.

None
remove_metadata Optional[List[str]]

The metadata to remove from the model.

None

Raises:

Type Description
RuntimeError

If mlflow fails to update the model.

Returns:

Type Description
RegisteredModel

The updated model.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
def update_model(
    self,
    name: str,
    description: Optional[str] = None,
    metadata: Optional[Dict[str, str]] = None,
    remove_metadata: Optional[List[str]] = None,
) -> RegisteredModel:
    """Update a model in the MLflow model registry.

    Args:
        name: The name of the model.
        description: The description of the model.
        metadata: The metadata of the model.
        remove_metadata: The metadata to remove from the model.

    Raises:
        RuntimeError: If mlflow fails to update the model.

    Returns:
        The updated model.
    """
    # Check if model exists (raises if it does not).
    self.get_model(name=name)
    # Update the registered model description.
    if description:
        try:
            self.mlflow_client.update_registered_model(
                name=name,
                description=description,
            )
        except MlflowException as e:
            # Chain the original MLflow error instead of discarding its
            # traceback.
            raise RuntimeError(
                f"Failed to update description for the model {name} in MLflow "
                f"model registry: {str(e)}",
            ) from e
    # Update the registered model tags.
    if metadata:
        try:
            for tag, value in metadata.items():
                self.mlflow_client.set_registered_model_tag(
                    name=name,
                    key=tag,
                    value=value,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update tags for the model {name} in MLflow model "
                f"registry: {str(e)}",
            ) from e
    # Remove tags from the registered model.
    if remove_metadata:
        try:
            for tag in remove_metadata:
                self.mlflow_client.delete_registered_model_tag(
                    name=name,
                    key=tag,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to remove tags for the model {name} in MLflow model "
                f"registry: {str(e)}",
            ) from e
    # Re-fetch so the returned object reflects all updates.
    return self.get_model(name)
update_model_version(name: str, version: str, description: Optional[str] = None, metadata: Optional[ModelRegistryModelMetadata] = None, remove_metadata: Optional[List[str]] = None, stage: Optional[ModelVersionStage] = None) -> RegistryModelVersion

Update a model version in the MLflow model registry.

Parameters:

Name Type Description Default
name str

The name of the model.

required
version str

The version of the model.

required
description Optional[str]

The description of the model version.

None
metadata Optional[ModelRegistryModelMetadata]

The metadata of the model version.

None
remove_metadata Optional[List[str]]

The metadata to remove from the model version.

None
stage Optional[ModelVersionStage]

The stage of the model version.

None

Raises:

Type Description
RuntimeError

If mlflow fails to update the model version.

Returns:

Type Description
RegistryModelVersion

The updated model version.

Source code in src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
def update_model_version(
    self,
    name: str,
    version: str,
    description: Optional[str] = None,
    metadata: Optional[ModelRegistryModelMetadata] = None,
    remove_metadata: Optional[List[str]] = None,
    stage: Optional[ModelVersionStage] = None,
) -> RegistryModelVersion:
    """Update a model version in the MLflow model registry.

    Args:
        name: The name of the model.
        version: The version of the model.
        description: The description of the model version.
        metadata: The metadata of the model version.
        remove_metadata: The metadata to remove from the model version.
        stage: The stage of the model version.

    Raises:
        RuntimeError: If mlflow fails to update the model version.

    Returns:
        The updated model version.
    """
    # Ensure the model version exists (raises if it does not).
    self.get_model_version(name=name, version=version)
    # Update the model description.
    if description:
        try:
            self.mlflow_client.update_model_version(
                name=name,
                version=version,
                description=description,
            )
        except MlflowException as e:
            # Chain the original MLflow error instead of discarding its
            # traceback.
            raise RuntimeError(
                f"Failed to update the description of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            ) from e
    # Update the model tags.
    if metadata:
        try:
            for key, value in metadata.model_dump().items():
                self.mlflow_client.set_model_version_tag(
                    name=name,
                    version=version,
                    key=key,
                    value=value,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update the tags of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            ) from e
    # Remove the model tags.
    if remove_metadata:
        try:
            for key in remove_metadata:
                self.mlflow_client.delete_model_version_tag(
                    name=name,
                    version=version,
                    key=key,
                )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to remove the tags of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            ) from e
    # Update the model stage.
    if stage:
        try:
            self.mlflow_client.transition_model_version_stage(
                name=name,
                version=version,
                stage=stage.value,
            )
        except MlflowException as e:
            raise RuntimeError(
                f"Failed to update the current stage of model version "
                f"'{name}:{version}' in the MLflow model registry: {str(e)}"
            ) from e
    # Re-fetch so the returned object reflects all updates.
    return self.get_model_version(name, version)
Functions

services

Initialization of the MLflow Service.

Classes
Modules
mlflow_deployment

Implementation of the MLflow deployment functionality.

Classes
MLFlowDeploymentConfig(**data: Any)

Bases: LocalDaemonServiceConfig

MLflow model deployment configuration.

Attributes:

Name Type Description
model_uri str

URI of the MLflow model to serve

model_name str

the name of the model

workers int

number of workers to use for the prediction service

registry_model_name Optional[str]

the name of the model in the registry

registry_model_version Optional[str]

the version of the model in the registry

mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

timeout int

timeout in seconds for starting and stopping the service

Source code in src/zenml/services/service.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def __init__(self, **data: Any):
    """Build the deployment service configuration.

    Derives a default ``service_name`` from the service or model name
    unless the caller supplied one explicitly.

    Args:
        **data: keyword arguments forwarded to the parent configuration.

    Raises:
        ValueError: if neither 'name' nor 'model_name' is set.
    """
    super().__init__(**data)
    base_name = self.name or self.model_name
    if not base_name:
        raise ValueError("Either 'name' or 'model_name' must be set.")
    # Only used as a default: an explicit "service_name" in `data` wins.
    self.service_name = data.get(
        "service_name", f"{ZENM_ENDPOINT_PREFIX}{base_name}"
    )
Functions
validate_mlserver_python_version(mlserver: bool) -> bool classmethod

Validates the Python version if mlserver is used.

Parameters:

Name Type Description Default
mlserver bool

set to True if the MLflow MLServer backend is used, else set to False and MLflow built-in scoring server will be used.

required

Returns:

Type Description
bool

the validated value

Raises:

Type Description
ValueError

if mlserver packages are not installed

Source code in src/zenml/integrations/mlflow/services/mlflow_deployment.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
@field_validator("mlserver")
@classmethod
def validate_mlserver_python_version(cls, mlserver: bool) -> bool:
    """Validates the Python version if mlserver is used.

    Args:
        mlserver: set to True if the MLflow MLServer backend is used,
            else set to False and MLflow built-in scoring server will be
            used.

    Returns:
        the validated value

    Raises:
        ValueError: if mlserver packages are not installed
    """
    if mlserver is True:
        # For the mlserver deployment, the mlserver and
        # mlserver-mlflow packages need to be installed separately
        # because they rely on an older version of Pydantic.

        # Check if the mlserver and mlserver-mlflow packages are installed
        try:
            importlib.import_module("mlserver")
            importlib.import_module("mlserver_mlflow")
        except ModuleNotFoundError as e:
            # Compare the full version tuple instead of only the minor
            # component: checking `sys.version_info.minor >= 12` alone
            # would misfire on any future major-version bump (e.g. a
            # hypothetical 4.0 would not be flagged).
            if sys.version_info >= (3, 12):
                raise ValueError(
                    "The mlserver deployment is not yet supported on "
                    "Python 3.12 or above."
                ) from e

            raise ValueError(
                "The MLflow MLServer backend requires the `mlserver` and "
                "`mlserver-mlflow` packages to be installed. These "
                "packages are not included in the ZenML MLflow integration "
                "requirements. Please install them manually."
            ) from e

    return mlserver
MLFlowDeploymentEndpoint(*args: Any, **kwargs: Any)

Bases: LocalDaemonServiceEndpoint

A service endpoint exposed by the MLflow deployment daemon.

Attributes:

Name Type Description
config MLFlowDeploymentEndpointConfig

service endpoint configuration

monitor HTTPEndpointHealthMonitor

optional service endpoint health monitor

Source code in src/zenml/services/service_endpoint.py
111
112
113
114
115
116
117
118
119
120
121
122
123
def __init__(
    self,
    *args: Any,
    **kwargs: Any,
) -> None:
    """Create the service endpoint, defaulting its display name.

    Args:
        *args: positional arguments.
        **kwargs: keyword arguments.
    """
    super().__init__(*args, **kwargs)
    # Fall back to the class name when no endpoint name was configured;
    # the assignment is unconditional to mirror the original contract.
    configured = self.config.name
    self.config.name = configured if configured else type(self).__name__
Attributes
prediction_url: Optional[str] property

Gets the prediction URL for the endpoint.

Returns:

Type Description
Optional[str]

the prediction URL for the endpoint

MLFlowDeploymentEndpointConfig

Bases: LocalDaemonServiceEndpointConfig

MLflow daemon service endpoint configuration.

Attributes:

Name Type Description
prediction_url_path str

URI subpath for prediction requests

MLFlowDeploymentService(config: Union[MLFlowDeploymentConfig, Dict[str, Any]], **attrs: Any)

Bases: LocalDaemonService, BaseDeploymentService

MLflow deployment service used to start a local prediction server for MLflow models.

Attributes:

Name Type Description
SERVICE_TYPE

a service type descriptor with information describing the MLflow deployment service class

config MLFlowDeploymentConfig

service configuration

endpoint MLFlowDeploymentEndpoint

optional service endpoint

Initialize the MLflow deployment service.

Parameters:

Name Type Description Default
config Union[MLFlowDeploymentConfig, Dict[str, Any]]

service configuration

required
attrs Any

additional attributes to set on the service

{}
Source code in src/zenml/integrations/mlflow/services/mlflow_deployment.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
def __init__(
    self,
    config: Union[MLFlowDeploymentConfig, Dict[str, Any]],
    **attrs: Any,
) -> None:
    """Initialize the MLflow deployment service.

    Args:
        config: service configuration
        attrs: additional attributes to set on the service
    """
    # ensure that the endpoint is created before the service is initialized
    # TODO [ENG-700]: implement a service factory or builder for MLflow
    #   deployment services
    needs_endpoint = (
        isinstance(config, MLFlowDeploymentConfig)
        and "endpoint" not in attrs
    )
    if needs_endpoint:
        # MLServer and the built-in scoring server expose different
        # prediction/health routes and tolerate different probe methods.
        if config.mlserver:
            url_path = MLSERVER_PREDICTION_URL_PATH
            health_path = MLSERVER_HEALTHCHECK_URL_PATH
            head_ok = False
        else:
            url_path = MLFLOW_PREDICTION_URL_PATH
            health_path = MLFLOW_HEALTHCHECK_URL_PATH
            head_ok = True

        attrs["endpoint"] = MLFlowDeploymentEndpoint(
            config=MLFlowDeploymentEndpointConfig(
                protocol=ServiceEndpointProtocol.HTTP,
                prediction_url_path=url_path,
            ),
            monitor=HTTPEndpointHealthMonitor(
                config=HTTPEndpointHealthMonitorConfig(
                    healthcheck_uri_path=health_path,
                    use_head_request=head_ok,
                )
            ),
        )
    super().__init__(config=config, **attrs)
Attributes
prediction_url: Optional[str] property

Get the URI where the prediction service is answering requests.

Returns:

Type Description
Optional[str]

The URI where the prediction service can be contacted to process

Optional[str]

HTTP/REST inference requests, or None, if the service isn't running.

Functions
predict(request: Union[NDArray[Any], pd.DataFrame]) -> NDArray[Any]

Make a prediction using the service.

Parameters:

Name Type Description Default
request Union[NDArray[Any], DataFrame]

a Numpy Array or Pandas DataFrame representing the request

required

Returns:

Type Description
NDArray[Any]

A numpy array representing the prediction returned by the service.

Raises:

Type Description
Exception

if the service is not running

ValueError

if the prediction endpoint is unknown.

Source code in src/zenml/integrations/mlflow/services/mlflow_deployment.py
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
def predict(
    self, request: Union["NDArray[Any]", pd.DataFrame]
) -> "NDArray[Any]":
    """Make a prediction using the service.

    Args:
        request: a Numpy Array or Pandas DataFrame representing the request

    Returns:
        A numpy array representing the prediction returned by the service.

    Raises:
        Exception: if the service is not running
        ValueError: if the prediction endpoint is unknown.
    """
    if not self.is_running:
        raise Exception(
            "MLflow prediction service is not running. "
            "Please start the service before making predictions."
        )

    if self.endpoint.prediction_url is None:
        raise ValueError("No endpoint known for prediction.")

    # Use isinstance rather than an exact `type(...) is` check so that
    # DataFrame subclasses are also serialized row-wise.
    if isinstance(request, pd.DataFrame):
        payload = {"instances": request.to_dict("records")}
    else:
        payload = {"instances": request.tolist()}
    response = requests.post(  # nosec
        self.endpoint.prediction_url,
        json=payload,
    )
    response.raise_for_status()
    if int(MLFLOW_VERSION.split(".")[0]) <= 1:
        return np.array(response.json())
    else:
        # Mlflow 2.0+ returns a dictionary with the predictions
        # under the "predictions" key
        return np.array(response.json()["predictions"])
run() -> None

Start the service.

Raises:

Type Description
ValueError

if the active stack doesn't have an MLflow experiment tracker

Source code in src/zenml/integrations/mlflow/services/mlflow_deployment.py
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
def run(self) -> None:
    """Start the service.

    Runs the MLflow scoring server as a blocking process in the current
    thread; this method only returns once the server is interrupted with
    CTRL+C.

    Raises:
        ValueError: if the active stack doesn't have an MLflow experiment
            tracker
    """
    logger.info(
        "Starting MLflow prediction service as blocking "
        "process... press CTRL+C once to stop it."
    )

    self.endpoint.prepare_for_start()
    try:
        backend_kwargs: Dict[str, Any] = {}
        serve_kwargs: Dict[str, Any] = {}
        # NOTE(review): assumes the MLFLOW_VERSION components are plain
        # integers; a pre-release suffix like "2.0.0rc1" would break the
        # int() conversions below — confirm upstream pins a final release.
        mlflow_version = MLFLOW_VERSION.split(".")
        # MLflow version 1.26 introduces an additional mandatory
        # `timeout` argument to the `PyFuncBackend.serve` function
        if int(mlflow_version[1]) >= 26 or int(mlflow_version[0]) >= 2:
            serve_kwargs["timeout"] = None
        # Mlflow 2.0+ requires the env_manager to be set to "local"
        # to run the deploy the model on the local running environment
        if int(mlflow_version[0]) >= 2:
            backend_kwargs["env_manager"] = "local"
        backend = PyFuncBackend(  # type: ignore[no-untyped-call]
            config={},
            no_conda=True,
            workers=self.config.workers,
            install_mlflow=False,
            **backend_kwargs,
        )
        # The MLflow client must point at the tracker's store before
        # serving, so the model URI resolves against the right backend.
        experiment_tracker = Client().active_stack.experiment_tracker
        if not isinstance(experiment_tracker, MLFlowExperimentTracker):
            raise ValueError(
                "MLflow model deployer step requires an MLflow experiment "
                "tracker. Please add an MLflow experiment tracker to your "
                "stack."
            )
        experiment_tracker.configure_mlflow()
        # Blocks here until the scoring server process exits.
        backend.serve(  # type: ignore[no-untyped-call]
            model_uri=self.config.model_uri,
            port=self.endpoint.status.port,
            host="localhost",
            enable_mlserver=self.config.mlserver,
            **serve_kwargs,
        )
    except KeyboardInterrupt:
        logger.info(
            "MLflow prediction service stopped. Resuming normal execution."
        )
Functions

steps

Initialization of the MLflow standard interface steps.

Functions
Modules
mlflow_deployer

Implementation of the MLflow model deployer pipeline step.

Classes Functions
mlflow_model_deployer_step(model: UnmaterializedArtifact, deploy_decision: bool = True, experiment_name: Optional[str] = None, run_name: Optional[str] = None, model_name: str = 'model', workers: int = 1, mlserver: bool = False, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT) -> Optional[MLFlowDeploymentService]

Model deployer pipeline step for MLflow.

This step deploys a model logged in the MLflow artifact store to a deployment service. The user would typically use this step in a pipeline that deploys a model that was already registered in the MLflow model registry, either manually or by using the mlflow_model_registry_step.

Parameters:

Name Type Description Default
deploy_decision bool

whether to deploy the model or not

True
model UnmaterializedArtifact

the model artifact to deploy

required
experiment_name Optional[str]

Name of the MLflow experiment in which the model was logged.

None
run_name Optional[str]

Name of the MLflow run in which the model was logged.

None
model_name str

the name of the MLflow model logged in the MLflow artifact store for the current pipeline.

'model'
workers int

number of workers to use for the prediction service

1
mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

False
timeout int

the number of seconds to wait for the service to start/stop.

DEFAULT_SERVICE_START_STOP_TIMEOUT

Returns:

Type Description
Optional[MLFlowDeploymentService]

MLflow deployment service or None if no service was deployed or found

Raises:

Type Description
ValueError

if the MLflow experiment tracker is not found

Source code in src/zenml/integrations/mlflow/steps/mlflow_deployer.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
@step(enable_cache=False)
def mlflow_model_deployer_step(
    model: UnmaterializedArtifact,
    deploy_decision: bool = True,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    model_name: str = "model",
    workers: int = 1,
    mlserver: bool = False,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> Optional[MLFlowDeploymentService]:
    """Model deployer pipeline step for MLflow.

    This step deploys a model logged in the MLflow artifact store to a
    deployment service. The user would typically use this step in a pipeline
    that deploys a model that was already registered in the MLflow model
    registry either manually or by using the `mlflow_model_registry_step`.

    Args:
        deploy_decision: whether to deploy the model or not
        model: the model artifact to deploy
        experiment_name: Name of the MLflow experiment in which the model was
            logged.
        run_name: Name of the MLflow run in which the model was logged.
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        timeout: the number of seconds to wait for the service to start/stop.

    Returns:
        MLflow deployment service or None if no service was deployed or found

    Raises:
        ValueError: if the MLflow experiment tracker is not found
    """
    model_deployer = cast(
        MLFlowModelDeployer, MLFlowModelDeployer.get_active_model_deployer()
    )

    experiment_tracker = Client().active_stack.experiment_tracker

    if not isinstance(experiment_tracker, MLFlowExperimentTracker):
        raise ValueError(
            "MLflow model deployer step requires an MLflow experiment "
            "tracker. Please add an MLflow experiment tracker to your "
            "stack."
        )

    # get pipeline name, step name and run id
    step_context = get_step_context()
    pipeline_name = step_context.pipeline.name
    # BUGFIX: use a distinct local name so the `run_name` step parameter is
    # not shadowed. Previously `run_name` was overwritten here, turning the
    # `run_name or run_name` fallback below into a tautology that always
    # ignored the caller-supplied run name.
    pipeline_run_name = step_context.pipeline_run.name
    step_name = step_context.step_run.name

    # Configure Mlflow so the client points to the correct store
    experiment_tracker.configure_mlflow()
    client = MlflowClient()
    # Prefer an explicitly provided run name and fall back to the current
    # pipeline run, mirroring the experiment_name fallback.
    mlflow_run_id = experiment_tracker.get_run_id(
        experiment_name=experiment_name or pipeline_name,
        run_name=run_name or pipeline_run_name,
    )

    # Fetch the model URI from the MLflow artifact store
    model_uri = ""
    if mlflow_run_id and client.list_artifacts(mlflow_run_id, model_name):
        model_uri = artifact_utils.get_artifact_uri(  # type: ignore[no-untyped-call]
            run_id=mlflow_run_id, artifact_path=model_name
        )

    predictor_cfg = MLFlowDeploymentConfig(
        model_name=model_name or "",
        model_uri=model_uri,
        workers=workers,
        mlserver=mlserver,
        pipeline_name=pipeline_name,
        pipeline_step_name=step_name,
        timeout=timeout,
    )

    # Fetch existing services with same pipeline name, step name and model name
    existing_services = model_deployer.find_model_server(
        config=predictor_cfg.model_dump(),
    )

    # Check whether to deploy a new service
    if model_uri and deploy_decision:
        new_service = cast(
            MLFlowDeploymentService,
            model_deployer.deploy_model(
                replace=True,
                config=predictor_cfg,
                timeout=timeout,
                service_type=MLFlowDeploymentService.SERVICE_TYPE,
            ),
        )
        logger.info(
            f"MLflow deployment service started and reachable at:\n"
            f"    {new_service.prediction_url}\n"
        )
        # Retire superseded servers only after the replacement is up.
        logger.info("Stopping existing services...")
        for existing_service in existing_services:
            existing_service.stop(timeout)
        return new_service

    # Check whether an existing service can be reused
    elif existing_services:
        if not model_uri:
            logger.info(
                f"Skipping model deployment because an MLflow model with name "
                f" `{model_name}` was not trained in the current pipeline run."
            )
        elif not deploy_decision:
            logger.info(
                "Skipping model deployment because the deployment decision "
                "was negative."
            )
        logger.info(
            f"Reusing last model server deployed by step `{step_name}` of "
            f"pipeline '{pipeline_name}' for model '{model_name}'..."
        )
        service = cast(MLFlowDeploymentService, existing_services[0])
        if not service.is_running:
            service.start(timeout)
        return service

    # Log a warning if no service was deployed or found
    elif deploy_decision and not model_uri:
        logger.warning(
            f"An MLflow model with name `{model_name}` was not "
            f"logged in the current pipeline run and no running MLflow "
            f"model server was found. Please ensure that your pipeline "
            f"includes a step with an MLflow experiment configured that "
            "trains a model and logs it to MLflow. This could also happen "
            "if the current pipeline run did not log an MLflow model  "
            f"because the training step was cached."
        )
    return None
mlflow_model_registry_deployer_step(registry_model_name: str, registry_model_version: Optional[str] = None, registry_model_stage: Optional[ModelVersionStage] = None, replace_existing: bool = True, model_name: str = 'model', workers: int = 1, mlserver: bool = False, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT) -> MLFlowDeploymentService

Model deployer pipeline step for MLflow.

Parameters:

Name Type Description Default
registry_model_name str

the name of the model in the model registry

required
registry_model_version Optional[str]

the version of the model in the model registry

None
registry_model_stage Optional[ModelVersionStage]

the stage of the model in the model registry

None
replace_existing bool

Whether to create a new deployment service or not

True
model_name str

the name of the MLflow model logged in the MLflow artifact store for the current pipeline.

'model'
workers int

number of workers to use for the prediction service

1
mlserver bool

set to True to use the MLflow MLServer backend (see https://github.com/SeldonIO/MLServer). If False, the MLflow built-in scoring server will be used.

False
timeout int

the number of seconds to wait for the service to start/stop.

DEFAULT_SERVICE_START_STOP_TIMEOUT

Returns:

Type Description
MLFlowDeploymentService

MLflow deployment service

Raises:

Type Description
ValueError

if neither registry_model_version nor registry_model_stage is provided

ValueError

if No MLflow experiment tracker is found in the current active stack

LookupError

if no model version is found in the MLflow model registry.

Source code in src/zenml/integrations/mlflow/steps/mlflow_deployer.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
@step(enable_cache=False)
def mlflow_model_registry_deployer_step(
    registry_model_name: str,
    registry_model_version: Optional[str] = None,
    registry_model_stage: Optional[ModelVersionStage] = None,
    replace_existing: bool = True,
    model_name: str = "model",
    workers: int = 1,
    mlserver: bool = False,
    timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT,
) -> MLFlowDeploymentService:
    """Model deployer pipeline step for MLflow.

    Args:
        registry_model_name: the name of the model in the model registry
        registry_model_version: the version of the model in the model registry
        registry_model_stage: the stage of the model in the model registry
        replace_existing: whether to stop any existing deployment services
            for the same model name/version once the new service is started
        model_name: the name of the MLflow model logged in the MLflow artifact
            store for the current pipeline.
        workers: number of workers to use for the prediction service
        mlserver: set to True to use the MLflow MLServer backend (see
            https://github.com/SeldonIO/MLServer). If False, the
            MLflow built-in scoring server will be used.
        timeout: the number of seconds to wait for the service to start/stop.

    Returns:
        MLflow deployment service

    Raises:
        ValueError: if neither registry_model_version nor registry_model_stage
            is provided
        ValueError: if no MLflow model registry is found in the current
            active stack, or the resolved model version is not an MLflow model
        LookupError: if no model version is found in the MLflow model registry.
    """
    # At least one selector besides the model name is required to pick a
    # concrete model version out of the registry.
    if not registry_model_version and not registry_model_stage:
        raise ValueError(
            "Either registry_model_version or registry_model_stage must "
            "be provided in addition to registry_model_name to the MLflow "
            "model registry deployer step. Since the "
            "mlflow_model_registry_deployer_step is used in conjunction with "
            "the mlflow_model_registry."
        )

    model_deployer = cast(
        MLFlowModelDeployer, MLFlowModelDeployer.get_active_model_deployer()
    )

    # fetch the MLflow model registry
    model_registry = Client().active_stack.model_registry
    if not isinstance(model_registry, MLFlowModelRegistry):
        raise ValueError(
            "The MLflow model registry step can only be used with an "
            "MLflow model registry."
        )

    # fetch the model version: an explicit version takes precedence over a
    # stage-based lookup
    model_version = None
    if registry_model_version:
        try:
            model_version = model_registry.get_model_version(
                name=registry_model_name,
                version=registry_model_version,
            )
        except KeyError:
            # fall through to the LookupError below, which carries the full
            # lookup context in its message
            model_version = None
    elif registry_model_stage:
        model_version = model_registry.get_latest_model_version(
            name=registry_model_name,
            stage=registry_model_stage,
        )
    if not model_version:
        raise LookupError(
            f"No Model Version found for model name "
            f"{registry_model_name} and version "
            f"{registry_model_version} or stage "
            f"{registry_model_stage}"
        )
    if model_version.model_format != MLFLOW_MODEL_FORMAT:
        # BUGFIX: the original message concatenated two sentences without a
        # separating space ("...MLflow model.Only MLflow models...").
        raise ValueError(
            f"Model version {model_version.version} of model "
            f"{model_version.registered_model.name} is not an MLflow model. "
            f"Only MLflow models can be deployed with the MLflow deployer "
            f"using this step."
        )
    # fetch existing services with the same model name and version so they
    # can be stopped once the replacement service is up
    existing_services = (
        model_deployer.find_model_server(
            model_name=registry_model_name,
            model_version=model_version.version,
        )
        if replace_existing
        else None
    )
    # create a config for the new model service
    metadata = model_version.metadata or ModelRegistryModelMetadata()
    predictor_cfg = MLFlowDeploymentConfig(
        name=model_name or None,
        model_name=registry_model_name,
        model_version=model_version.version,
        model_uri=model_version.model_source_uri,
        workers=workers,
        mlserver=mlserver,
        pipeline_name=metadata.zenml_pipeline_name or "",
        pipeline_step_name=metadata.zenml_step_name or "",
        timeout=timeout,
    )

    # create a new model deployment and replace an old one if it exists
    new_service = cast(
        MLFlowDeploymentService,
        model_deployer.deploy_model(
            replace=True,
            config=predictor_cfg,
            timeout=timeout,
            service_type=MLFlowDeploymentService.SERVICE_TYPE,
        ),
    )

    logger.info(
        f"MLflow deployment service started and reachable at:\n"
        f"    {new_service.prediction_url}\n"
    )

    # only stop the superseded services after the new one is reachable
    if existing_services:
        logger.info("Stopping existing services...")
        for existing_service in existing_services:
            existing_service.stop(timeout)

    return new_service
mlflow_registry

Implementation of the MLflow model registration pipeline step.

Classes Functions
mlflow_register_model_step(model: UnmaterializedArtifact, name: str, version: Optional[str] = None, trained_model_name: Optional[str] = 'model', model_source_uri: Optional[str] = None, experiment_name: Optional[str] = None, run_name: Optional[str] = None, run_id: Optional[str] = None, description: Optional[str] = None, metadata: Optional[ModelRegistryModelMetadata] = None) -> None

MLflow model registry step.

Parameters:

Name Type Description Default
model UnmaterializedArtifact

Model to be registered. This is not used in the step, but is required to trigger the step when the model is trained.

required
name str

The name of the model.

required
version Optional[str]

The version of the model.

None
trained_model_name Optional[str]

Name of the model artifact in MLflow.

'model'
model_source_uri Optional[str]

The path to the model. If not provided, the model will be fetched from the MLflow tracking server via the trained_model_name.

None
experiment_name Optional[str]

Name of the experiment to be used for the run.

None
run_name Optional[str]

Name of the run to be created.

None
run_id Optional[str]

ID of the run to be used.

None
description Optional[str]

A description of the model version.

None
metadata Optional[ModelRegistryModelMetadata]

A list of metadata to associate with the model version

None

Raises:

Type Description
ValueError

If the model registry is not an MLflow model registry.

ValueError

If the experiment tracker is not an MLflow experiment tracker.

RuntimeError

If no model source URI is provided and no model is found.

RuntimeError

If no run ID is provided and no run is found.

Source code in src/zenml/integrations/mlflow/steps/mlflow_registry.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
@step(enable_cache=True)
def mlflow_register_model_step(
    model: UnmaterializedArtifact,
    name: str,
    version: Optional[str] = None,
    trained_model_name: Optional[str] = "model",
    model_source_uri: Optional[str] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    run_id: Optional[str] = None,
    description: Optional[str] = None,
    metadata: Optional[ModelRegistryModelMetadata] = None,
) -> None:
    """Register a trained MLflow model in the MLflow model registry.

    Args:
        model: Model to be registered. Not consumed directly; it only serves
            to order this step after the training step that produced it.
        name: The name under which to register the model.
        version: The version to register; defaults to "1" when not given.
        trained_model_name: Name of the model artifact logged in MLflow.
        model_source_uri: The path to the model. When omitted, the URI is
            resolved from the MLflow tracking server via `trained_model_name`.
        experiment_name: Name of the MLflow experiment used for the run.
        run_name: Name of the MLflow run to look up.
        run_id: ID of the MLflow run to use, which skips the name lookup.
        description: A description of the model version.
        metadata: Metadata to associate with the model version.

    Raises:
        ValueError: If the model registry is not an MLflow model registry.
        ValueError: If the experiment tracker is not an MLflow experiment tracker.
        RuntimeError: If no model source URI is provided and no model is found.
        RuntimeError: If no run ID is provided and no run is found.
    """
    # The active stack must track experiments with MLflow ...
    tracker = Client().active_stack.experiment_tracker
    if not isinstance(tracker, MLFlowExperimentTracker):
        raise ValueError(
            "The MLflow model registry step can only be used with an "
            "MLflow experiment tracker. Please add an MLflow experiment "
            "tracker to your stack."
        )

    # ... and register models with an MLflow registry as well.
    registry = Client().active_stack.model_registry
    if not isinstance(registry, MLFlowModelRegistry):
        raise ValueError(
            "The MLflow model registry step can only be used with an "
            "MLflow model registry."
        )

    # Pull pipeline/run identifiers from the active step context.
    context = get_step_context()
    pipeline_name = context.pipeline.name
    active_run_name = context.pipeline_run.name
    run_uuid = str(context.pipeline_run.id)
    project_name = str(context.pipeline.project.name)

    # Resolve the MLflow run: an explicit ID wins; otherwise look it up by
    # experiment and run name through the experiment tracker.
    mlflow_run_id = run_id or tracker.get_run_id(
        experiment_name=experiment_name or pipeline_name,
        run_name=run_name or active_run_name,
    )
    if not mlflow_run_id:
        raise RuntimeError(
            f"Could not find MLflow run for experiment {pipeline_name} "
            f"and run {run_name}."
        )

    # Confirm the resolved run actually exists on the tracking server.
    mlflow_client = registry.mlflow_client
    try:
        mlflow_client.get_run(run_id=mlflow_run_id).info.run_id
    except Exception:
        raise RuntimeError(
            f"Could not find MLflow run with ID {mlflow_run_id}."
        )

    # Normalize an empty source URI to None before attempting resolution.
    model_source_uri = model_source_uri or None

    # Without an explicit URI, locate the logged model artifact on the run.
    if not model_source_uri:
        if mlflow_client.list_artifacts(mlflow_run_id, trained_model_name):
            model_source_uri = artifact_utils.get_artifact_uri(  # type: ignore[no-untyped-call]
                run_id=mlflow_run_id, artifact_path=trained_model_name
            )
    if not model_source_uri:
        raise RuntimeError(
            "No model source URI provided or no model found in the "
            "MLflow tracking server for the given inputs."
        )

    # Fill in any metadata fields the caller left unset.
    metadata = metadata or ModelRegistryModelMetadata()
    if metadata.zenml_version is None:
        metadata.zenml_version = __version__
    if metadata.zenml_pipeline_name is None:
        metadata.zenml_pipeline_name = pipeline_name
    if metadata.zenml_run_name is None:
        # NOTE(review): this records the `run_name` argument, which may be
        # None when the MLflow run was resolved from the pipeline run name —
        # confirm whether the ZenML run name was intended here instead.
        metadata.zenml_run_name = run_name
    if metadata.zenml_pipeline_run_uuid is None:
        metadata.zenml_pipeline_run_uuid = run_uuid
    if metadata.zenml_project is None:
        metadata.zenml_project = project_name
    if getattr(metadata, "mlflow_run_id", None) is None:
        # `mlflow_run_id` may not be a declared field, hence getattr/setattr.
        setattr(metadata, "mlflow_run_id", mlflow_run_id)

    # Register the model version and report where it came from.
    model_version = registry.register_model_version(
        name=name,
        version=version or "1",
        model_source_uri=model_source_uri,
        description=description,
        metadata=metadata,
    )

    logger.info(
        f"Registered model {name} "
        f"with version {model_version.version} "
        f"from source {model_source_uri}."
    )